cortex-agent 0.2.1

Self-learning AI agent with persistent memory, tools, plugins, and a beautiful terminal UI
use std::pin::Pin;
use std::sync::Arc;

use async_trait::async_trait;
use futures::Stream;

use crate::messages::Message;

/// Provider error types.
#[derive(Debug, thiserror::Error)]
pub enum ProviderError {
    #[error("HTTP request failed: {0}")]
    Http(String),
    #[error("API error {status}: {body}")]
    Api { status: u16, body: String },
    #[error("Stream error: {0}")]
    Stream(String),
    #[error("Timeout: {0}")]
    Timeout(String),
    #[error("{0}")]
    Other(String),
}

/// Abstract interface for LLM providers.
///
/// The trait requires `Send + Sync`, so one provider value can be shared across
/// tasks (the forwarding impls for `Box<T>` and `Arc<T>` rely on this).
#[async_trait]
pub trait Provider: Sync + Send {
    /// Run a single, non-streaming chat completion and return the reply message.
    ///
    /// * `messages` – the conversation to send.
    /// * `tools` – optional tool definitions, as raw JSON values.
    /// * `tool_choice` – tool-selection mode as a string
    ///   (NOTE(review): the accepted values are defined by each implementation — confirm).
    /// * `max_tokens` – optional limit on generated tokens; `None` leaves it to the implementation.
    /// * `temperature` – sampling temperature.
    ///
    /// # Errors
    /// Returns a [`ProviderError`] when the completion cannot be produced.
    async fn chat_completion(
        &self,
        messages: &[Message],
        tools: Option<&[serde_json::Value]>,
        tool_choice: &str,
        max_tokens: Option<u32>,
        temperature: f32,
    ) -> Result<Message, ProviderError>;

    /// Start a streaming chat completion whose items are text tokens.
    ///
    /// Takes the same arguments as [`Provider::chat_completion`]. The outer
    /// `Result` reports failure to start the stream; each stream item may carry
    /// its own [`ProviderError`].
    ///
    /// The fully assembled message (including any tool calls) is exposed
    /// afterwards through [`Provider::last_stream_message`].
    async fn chat_completion_stream(
        &self,
        messages: &[Message],
        tools: Option<&[serde_json::Value]>,
        tool_choice: &str,
        max_tokens: Option<u32>,
        temperature: f32,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<String, ProviderError>> + Send>>, ProviderError>;

    /// Message produced by the most recently completed stream, including its
    /// tool calls if there were any, or `None` when unavailable.
    ///
    /// NOTE(review): this accessor takes `&self`, so implementations presumably
    /// keep the value behind interior mutability. With concurrent streams on one
    /// shared provider, the result may belong to a different call — confirm.
    fn last_stream_message(&self) -> Option<Message>;

    /// Usage information from the most recent completion, if the provider reports it.
    fn last_usage(&self) -> Option<crate::messages::Usage>;

    /// Embed `text` into a vector.
    ///
    /// The default implementation returns `None`, meaning "embeddings not supported".
    async fn embed(&self, _text: &str) -> Option<Vec<f32>> {
        None
    }
}

/// Forwarding impl so a boxed provider (including `Box<dyn Provider>`) is itself
/// a [`Provider`] and can be passed wherever one is expected.
///
/// Every trait method, `embed` included, delegates to the boxed value, so boxing
/// a provider does not change its behavior.
#[async_trait]
impl<T: Provider + ?Sized> Provider for Box<T> {
    async fn chat_completion(
        &self,
        messages: &[Message],
        tools: Option<&[serde_json::Value]>,
        tool_choice: &str,
        max_tokens: Option<u32>,
        temperature: f32,
    ) -> Result<Message, ProviderError> {
        (**self).chat_completion(messages, tools, tool_choice, max_tokens, temperature).await
    }

    async fn chat_completion_stream(
        &self,
        messages: &[Message],
        tools: Option<&[serde_json::Value]>,
        tool_choice: &str,
        max_tokens: Option<u32>,
        temperature: f32,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<String, ProviderError>> + Send>>, ProviderError> {
        (**self).chat_completion_stream(messages, tools, tool_choice, max_tokens, temperature).await
    }

    fn last_stream_message(&self) -> Option<Message> {
        (**self).last_stream_message()
    }

    fn last_usage(&self) -> Option<crate::messages::Usage> {
        (**self).last_usage()
    }

    async fn embed(&self, text: &str) -> Option<Vec<f32>> {
        // Must be forwarded explicitly: falling back to the trait's default
        // would always return `None` and hide the inner provider's embeddings.
        (**self).embed(text).await
    }
}

/// Forwarding impl so a shared provider (`Arc<T>`, including `Arc<dyn Provider>`)
/// is itself a [`Provider`].
///
/// Every trait method, `embed` included, delegates to the shared value, so
/// wrapping a provider in an `Arc` does not change its behavior.
#[async_trait]
impl<T: Provider + ?Sized> Provider for Arc<T> {
    async fn chat_completion(
        &self,
        messages: &[Message],
        tools: Option<&[serde_json::Value]>,
        tool_choice: &str,
        max_tokens: Option<u32>,
        temperature: f32,
    ) -> Result<Message, ProviderError> {
        (**self).chat_completion(messages, tools, tool_choice, max_tokens, temperature).await
    }

    async fn chat_completion_stream(
        &self,
        messages: &[Message],
        tools: Option<&[serde_json::Value]>,
        tool_choice: &str,
        max_tokens: Option<u32>,
        temperature: f32,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<String, ProviderError>> + Send>>, ProviderError> {
        (**self).chat_completion_stream(messages, tools, tool_choice, max_tokens, temperature).await
    }

    fn last_stream_message(&self) -> Option<Message> {
        (**self).last_stream_message()
    }

    fn last_usage(&self) -> Option<crate::messages::Usage> {
        (**self).last_usage()
    }

    async fn embed(&self, text: &str) -> Option<Vec<f32>> {
        // Must be forwarded explicitly: falling back to the trait's default
        // would always return `None` and hide the inner provider's embeddings.
        (**self).embed(text).await
    }
}