simple-agents-core 0.2.34

Core SimpleAgents client API integrating providers, routing, caching, and healing
Documentation
use async_trait::async_trait;
use simple_agent_type::prelude::*;
use simple_agents_cache::InMemoryCache;
use simple_agents_core::{
    CompletionMode, CompletionOptions, CompletionOutcome, HealingSettings, Middleware, RoutingMode,
    SimpleAgentsClientBuilder,
};
use simple_agents_healing::schema::Schema;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;

/// Provider stand-in used by the tests in this file.
///
/// It reports a fixed name, answers with fixed content, and keeps a counter
/// that its `Provider::execute` implementation increments.
struct MockProvider {
    /// Name reported by the provider and echoed in its responses.
    name: &'static str,
    /// Text used as the assistant message of every response.
    content: &'static str,
    /// Number of `execute` invocations observed so far.
    calls: AtomicUsize,
}

impl MockProvider {
    /// Creates a mock with the given name and canned reply.
    ///
    /// The call counter starts at zero.
    fn new(name: &'static str, content: &'static str) -> Self {
        let calls = AtomicUsize::new(0);
        MockProvider {
            name,
            content,
            calls,
        }
    }

    /// Returns the current value of the call counter.
    fn calls(&self) -> usize {
        let counter = &self.calls;
        counter.load(Ordering::Relaxed)
    }
}

// `Provider` implementation that returns canned data and only records that
// `execute` was called.
#[async_trait]
impl Provider for MockProvider {
    /// Returns the name this mock was constructed with.
    fn name(&self) -> &str {
        self.name
    }

    /// Ignores the completion request and produces a request aimed at a fixed
    /// placeholder URL.
    fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
        let request = ProviderRequest::new("http://example.com");
        Ok(request)
    }

    /// Bumps the call counter and answers with status 200 and a JSON `null` body.
    async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
        self.calls.fetch_add(1, Ordering::Relaxed);
        let body = serde_json::Value::Null;
        Ok(ProviderResponse::new(200, body))
    }

    /// Ignores the raw provider response and builds a single-choice completion
    /// whose assistant message is the mock's canned content.
    fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
        let only_choice = CompletionChoice {
            index: 0,
            message: Message::assistant(self.content),
            finish_reason: FinishReason::Stop,
            logprobs: None,
        };

        let response = CompletionResponse {
            id: String::from("resp_test"),
            model: String::from("test-model"),
            choices: vec![only_choice],
            usage: Usage::new(1, 1),
            created: None,
            // Tag the response with the mock's own name.
            provider: Some(self.name.to_owned()),
            healing_metadata: None,
        };
        Ok(response)
    }
}

/// Middleware test double that counts how often each hook is invoked.
struct TrackingMiddleware {
    /// Incremented by the `before_request` hook.
    before: AtomicUsize,
    /// Incremented by the `after_response` hook.
    after: AtomicUsize,
    /// Incremented by the `on_cache_hit` hook.
    cache_hits: AtomicUsize,
    /// Incremented by the `on_error` hook.
    errors: AtomicUsize,
    /// Latency, in milliseconds, most recently passed to `after_response`;
    /// `None` until that hook has stored a value.
    last_latency_ms: Mutex<Option<u128>>,
}

impl TrackingMiddleware {
    /// Creates a tracker with every counter at zero and no latency recorded.
    fn new() -> Self {
        let zero = || AtomicUsize::new(0);
        TrackingMiddleware {
            before: zero(),
            after: zero(),
            cache_hits: zero(),
            errors: zero(),
            last_latency_ms: Mutex::new(None),
        }
    }
}

#[async_trait]
impl Middleware for TrackingMiddleware {
    async fn before_request(&self, _request: &CompletionRequest) -> Result<()> {
        self.before.fetch_add(1, Ordering::Relaxed);
        Ok(())
    }

    async fn after_response(
        &self,
        _request: &CompletionRequest,
        _response: &CompletionResponse,
        latency: Duration,
    ) -> Result<()> {
        self.after.fetch_add(1, Ordering::Relaxed);
        let mut lock = self
            .last_latency_ms
            .lock()
            .map_err(|_| SimpleAgentsError::Config("latency lock poisoned".to_string()))?;
        *lock = Some(latency.as_millis());
        Ok(())
    }

    async fn on_cache_hit(
        &self,
        _request: &CompletionRequest,
        _response: &CompletionResponse,
    ) -> Result<()> {
        self.cache_hits.fetch_add(1, Ordering::Relaxed);
        Ok(())
    }

    async fn on_error(
        &self,
        _request: &CompletionRequest,
        _error: &SimpleAgentsError,
        _latency: Duration,
    ) -> Result<()> {
        self.errors.fetch_add(1, Ordering::Relaxed);
        Ok(())
    }
}

/// Builds the request shared by the tests: model `"test-model"` with a single
/// user message saying `"hello"`.
///
/// Panics if the builder rejects the request.
fn build_request() -> CompletionRequest {
    let greeting = Message::user("hello");
    CompletionRequest::builder()
        .model("test-model")
        .message(greeting)
        .build()
        .unwrap()
}

// Issues the same request twice against a cache-enabled client and checks
// that both calls yield the canned content while the provider ran only once.
#[tokio::test]
async fn complete_uses_cache() {
    let provider = Arc::new(MockProvider::new("p1", "ok"));
    let cache = Arc::new(InMemoryCache::new(1024 * 1024, 10));
    let ttl = Duration::from_secs(60);

    let client = SimpleAgentsClientBuilder::new()
        .with_provider(provider.clone())
        .with_cache(cache)
        .with_cache_ttl(ttl)
        .build()
        .unwrap();

    let request = build_request();

    // Both completions are awaited before either outcome is inspected.
    let first_outcome = client
        .complete(&request, CompletionOptions::default())
        .await
        .unwrap();
    let second_outcome = client
        .complete(&request, CompletionOptions::default())
        .await
        .unwrap();

    let CompletionOutcome::Response(first) = first_outcome else {
        panic!("expected response");
    };
    let CompletionOutcome::Response(second) = second_outcome else {
        panic!("expected response");
    };

    assert_eq!(first.content(), Some("ok"));
    assert_eq!(second.content(), Some("ok"));
    // The provider must have been executed only once across the two calls.
    assert_eq!(provider.calls(), 1);
}

// The provider replies with JSON wrapped in a Markdown code fence and carrying
// a trailing comma; `HealedJson` mode is expected to still yield `value == 42`.
#[tokio::test]
async fn complete_json_parses_markdown() {
    let fenced_reply = "```json\n{\"value\": 42,}\n```";
    let provider = Arc::new(MockProvider::new("p1", fenced_reply));

    let client = SimpleAgentsClientBuilder::new()
        .with_provider(provider)
        .with_healing_settings(HealingSettings::default())
        .build()
        .unwrap();

    let request = build_request();
    let options = CompletionOptions {
        mode: CompletionMode::HealedJson,
    };

    let outcome = client.complete(&request, options).await.unwrap();
    let CompletionOutcome::HealedJson(healed) = outcome else {
        panic!("expected healed json");
    };

    assert_eq!(healed.parsed.value["value"], 42);
}

// The provider replies with `count` as the string "5"; completing against a
// schema that declares `count` as an integer is expected to yield the number 5.
#[tokio::test]
async fn complete_with_schema_coerces_types() {
    let string_typed_reply = "{\"count\": \"5\"}";
    let provider = Arc::new(MockProvider::new("p1", string_typed_reply));

    let client = SimpleAgentsClientBuilder::new()
        .with_provider(provider)
        .with_healing_settings(HealingSettings::default())
        .build()
        .unwrap();

    let schema = Schema::object(vec![("count".into(), Schema::Int, true)]);
    let request = build_request();
    let options = CompletionOptions {
        mode: CompletionMode::CoercedSchema(schema),
    };

    let outcome = client.complete(&request, options).await.unwrap();
    let CompletionOutcome::CoercedSchema(healed) = outcome else {
        panic!("expected coerced schema");
    };

    assert_eq!(healed.coerced.value["count"], 5);
}

// Runs one successful completion through a client that has a tracking
// middleware installed (and no cache configured) and checks every piece of
// state the middleware records: the request/response hooks ran once, the
// cache-hit and error hooks did not run, and a latency was stored.
#[tokio::test]
async fn middleware_hooks_fire() {
    let provider = Arc::new(MockProvider::new("p1", "ok"));
    let middleware = Arc::new(TrackingMiddleware::new());

    let client = SimpleAgentsClientBuilder::new()
        .with_provider(provider)
        .with_routing_mode(RoutingMode::RoundRobin)
        .with_middleware(middleware.clone())
        .build()
        .unwrap();

    // Only the side effects on the middleware matter here; the outcome itself
    // is discarded once it is known to be `Ok`.
    let _ = client
        .complete(&build_request(), CompletionOptions::default())
        .await
        .unwrap();

    assert_eq!(middleware.before.load(Ordering::Relaxed), 1);
    assert_eq!(middleware.after.load(Ordering::Relaxed), 1);
    assert_eq!(middleware.cache_hits.load(Ordering::Relaxed), 0);

    // A successful completion must not be reported through `on_error`.
    assert_eq!(middleware.errors.load(Ordering::Relaxed), 0);

    // `after_response` stores the latency it receives, so a value must be
    // present once that hook has run.
    let recorded_latency = middleware
        .last_latency_ms
        .lock()
        .expect("latency lock poisoned");
    assert!(recorded_latency.is_some());
}