// cortex-agent 0.2.1
//
// Self-learning AI agent with persistent memory, tools, plugins, and a beautiful terminal UI.
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::{Arc, Mutex};

use async_trait::async_trait;
use bytes::Bytes;
use futures::Stream;
use futures_util::StreamExt;
use reqwest::Client;
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;

use crate::messages::{Message, ToolCall};
use crate::provider::{Provider, ProviderError};

/// OpenAI-compatible LLM provider with streaming.
///
/// Holds the connection settings for a `chat/completions`-style HTTP API plus
/// two shared slots that the request methods fill in as a side effect.
pub struct OpenAICompatibleProvider {
    /// Model identifier sent as the `model` field of each chat request.
    model: String,
    /// Secret sent as a `Bearer` token in the `Authorization` header.
    api_key: String,
    /// API root; endpoint paths are appended after trailing slashes are trimmed.
    base_url: String,
    /// Attempt budget for the non-streaming `chat_completion` retry loop.
    max_retries: u32,
    /// Request timeout, in seconds.
    timeout: u64,
    /// HTTP client used for every request issued by this provider.
    client: Client,
    /// Assistant message assembled by the most recent streaming call.
    /// The accessor method takes the value, leaving the slot empty.
    pub last_stream_message: Arc<Mutex<Option<Message>>>,
    /// Token usage reported by the most recent response that carried it.
    /// The accessor method takes the value, leaving the slot empty.
    pub last_usage: Arc<Mutex<Option<crate::messages::Usage>>>,
}

impl OpenAICompatibleProvider {
    /// Builds a provider for `model` served at `base_url`.
    ///
    /// `timeout` is in seconds and is installed on the HTTP client. If the
    /// client builder fails, a default client is used instead.
    pub fn new(
        model: String,
        api_key: String,
        base_url: String,
        max_retries: u32,
        timeout: u64,
    ) -> Self {
        let request_timeout = std::time::Duration::from_secs(timeout);
        let client = match Client::builder().timeout(request_timeout).build() {
            Ok(configured) => configured,
            // Same fallback as `unwrap_or_default`: a plain default client.
            Err(_) => Client::default(),
        };

        // Both slots start empty; the request methods populate them.
        let last_stream_message = Arc::new(Mutex::new(None));
        let last_usage = Arc::new(Mutex::new(None));

        Self {
            model,
            api_key,
            base_url,
            max_retries,
            timeout,
            client,
            last_stream_message,
            last_usage,
        }
    }

    /// Assembles the JSON request body for a chat completion call.
    ///
    /// * `messages` – conversation history, serialized via `to_api_dict`.
    /// * `tools` – optional tool definitions; `None` or an empty slice means
    ///   no tool fields are sent at all.
    /// * `tool_choice` – forwarded only when tool definitions are present and
    ///   the value is non-empty.
    /// * `max_tokens` – added as `max_tokens` when provided.
    /// * `stream` – adds `"stream": true` when set.
    fn build_body(
        &self,
        messages: &[Message],
        tools: Option<&[serde_json::Value]>,
        tool_choice: &str,
        max_tokens: Option<u32>,
        temperature: f32,
        stream: bool,
    ) -> serde_json::Value {
        let api_messages: Vec<_> = messages.iter().map(|m| m.to_api_dict()).collect();
        let mut body = serde_json::json!({
            "model": self.model,
            "messages": api_messages,
            "temperature": temperature,
        });

        // OpenAI-style servers reject `tool_choice` when `tools` is absent and
        // reject an empty `tools` array, so both fields travel together and
        // only when there is at least one tool definition.
        if let Some(defs) = tools.filter(|defs| !defs.is_empty()) {
            body["tools"] = serde_json::Value::Array(defs.to_vec());
            if !tool_choice.is_empty() {
                body["tool_choice"] = serde_json::Value::String(tool_choice.to_owned());
            }
        }
        if let Some(limit) = max_tokens {
            body["max_tokens"] = serde_json::Value::from(limit);
        }
        if stream {
            body["stream"] = serde_json::Value::Bool(true);
        }
        body
    }

    /// Builds the headers shared by all requests: bearer authorization and a
    /// JSON content type.
    ///
    /// The API key is trimmed before use. If it still contains bytes that are
    /// not legal in a header value, the `Authorization` header is omitted
    /// instead of panicking; the request then goes out unauthenticated and the
    /// server's rejection surfaces through the normal error path.
    fn headers(&self) -> reqwest::header::HeaderMap {
        use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};

        let mut headers = HeaderMap::new();

        // Keys loaded from env files often carry a trailing newline, which is
        // not a valid header byte.
        let bearer = format!("Bearer {}", self.api_key.trim());
        if let Ok(mut value) = HeaderValue::from_str(&bearer) {
            // Keeps the secret out of `Debug` output of the header map.
            value.set_sensitive(true);
            headers.insert(AUTHORIZATION, value);
        }

        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
        headers
    }

    /// Returns the chat-completions endpoint: the base URL with any trailing
    /// slashes removed, followed by `/chat/completions`.
    fn url(&self) -> String {
        const PATH: &str = "/chat/completions";
        let root = self.base_url.trim_end_matches('/');
        let mut endpoint = String::with_capacity(root.len() + PATH.len());
        endpoint.push_str(root);
        endpoint.push_str(PATH);
        endpoint
    }
}

#[async_trait]
impl Provider for OpenAICompatibleProvider {
    async fn chat_completion(
        &self,
        messages: &[Message],
        tools: Option<&[serde_json::Value]>,
        tool_choice: &str,
        max_tokens: Option<u32>,
        temperature: f32,
    ) -> Result<Message, ProviderError> {
        let body = self.build_body(messages, tools, tool_choice, max_tokens, temperature, false);
        let mut last_error: Option<ProviderError> = None;

        for attempt in 0..self.max_retries {
            match self.client.post(&self.url()).headers(self.headers()).json(&body).send().await {
                Ok(resp) => {
                    if resp.status().is_success() {
                        let data: serde_json::Value = resp.json().await.map_err(|e| ProviderError::Other(format!("JSON parse error: {}", e)))?;
                        let choice = &data["choices"][0]["message"];
                        // Extract usage if available
                        if let Some(usage) = data.get("usage") {
                            let u = crate::messages::Usage {
                                prompt_tokens: usage.get("prompt_tokens").and_then(|v| v.as_u64()).unwrap_or(0),
                                completion_tokens: usage.get("completion_tokens").and_then(|v| v.as_u64()).unwrap_or(0),
                                total_tokens: usage.get("total_tokens").and_then(|v| v.as_u64()).unwrap_or(0),
                                ..Default::default()
                            };
                            if let Ok(mut last) = self.last_usage.lock() { *last = Some(u); }
                        }
                        return Ok(Message::from_api_dict(choice));
                    }
                    let status = resp.status().as_u16();
                    let text = resp.text().await.unwrap_or_default();
                    if status == 429 || (500..=599).contains(&status) {
                        last_error = Some(ProviderError::Api { status, body: text.clone() });
                        if attempt < self.max_retries - 1 { tokio::time::sleep(std::time::Duration::from_secs(2u64.pow(attempt))).await; continue; }
                    }
                    return Err(ProviderError::Api { status, body: text });
                }
                Err(e) if e.is_timeout() => {
                    last_error = Some(ProviderError::Timeout(format!("Request timed out ({}s)", self.timeout)));
                    if attempt < self.max_retries - 1 { tokio::time::sleep(std::time::Duration::from_secs(2u64.pow(attempt))).await; continue; }
                }
                Err(e) => {
                    last_error = Some(ProviderError::Http(format!("Request failed: {}", e)));
                    if attempt < self.max_retries - 1 { tokio::time::sleep(std::time::Duration::from_secs(2u64.pow(attempt))).await; continue; }
                }
            }
        }
        Err(last_error.unwrap_or_else(|| ProviderError::Other("Provider call failed after all retries".into())))
    }

    async fn chat_completion_stream(
        &self,
        messages: &[Message],
        tools: Option<&[serde_json::Value]>,
        tool_choice: &str,
        max_tokens: Option<u32>,
        temperature: f32,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<String, ProviderError>> + Send>>, ProviderError> {
        let body = self.build_body(messages, tools, tool_choice, max_tokens, temperature, true);
        let (tx, rx) = mpsc::unbounded_channel();

        let client = self.client.clone();
        let url = self.url();
        let headers = self.headers();
        let timeout_dur = std::time::Duration::from_secs(self.timeout);
        let last_stream = self.last_stream_message.clone();
        let last_stream_usage = self.last_usage.clone();

        tokio::spawn(async move {
            let result = client.post(&url).headers(headers).json(&body).timeout(timeout_dur).send().await;
            let response = match result {
                Ok(r) => r,
                Err(e) => { let _ = tx.send(Err(ProviderError::Stream(format!("Request failed: {}", e)))); return; }
            };

            if !response.status().is_success() {
                let status = response.status().as_u16();
                let text = response.text().await.unwrap_or_default();
                let _ = tx.send(Err(ProviderError::Api { status, body: text }));
                return;
            };

            let mut content_parts: Vec<String> = Vec::new();
            let mut tool_call_map: HashMap<usize, ToolCallBuilder> = HashMap::new();
            let mut buffer = String::new();

            let mut stream = response.bytes_stream();
            while let Some(chunk_result) = stream.next().await {
                let chunk: Bytes = match chunk_result {
                    Ok(c) => c,
                    Err(e) => { let _ = tx.send(Err(ProviderError::Stream(format!("Read error: {}", e)))); return; }
                };
                buffer.push_str(&String::from_utf8_lossy(&chunk));
                while let Some(newline) = buffer.find('\n') {
                    let line = buffer[..newline].to_string();
                    buffer = buffer[newline + 1..].to_string();
                    let line = line.trim().to_string();
                    if !line.starts_with("data: ") { continue; }
                    let payload = line[6..].trim().to_string();
                    if payload == "[DONE]" { break; }
                    let chunk_value: serde_json::Value = match serde_json::from_str(&payload) { Ok(v) => v, Err(_) => continue };
                    let choices = &chunk_value["choices"];
                    if choices.as_array().map_or(true, |a| a.is_empty()) {
                        // Empty choices — might contain usage info (final chunk)
                        if let Some(usage) = chunk_value.get("usage") {
                            let u = crate::messages::Usage {
                                prompt_tokens: usage.get("prompt_tokens").and_then(|v| v.as_u64()).unwrap_or(0),
                                completion_tokens: usage.get("completion_tokens").and_then(|v| v.as_u64()).unwrap_or(0),
                                total_tokens: usage.get("total_tokens").and_then(|v| v.as_u64()).unwrap_or(0),
                                ..Default::default()
                            };
                            if let Ok(mut last) = last_stream_usage.lock() { *last = Some(u); }
                        }
                        continue;
                    }
                    let delta = &choices[0]["delta"];
                    if let Some(content) = delta["content"].as_str() {
                        if !content.is_empty() { content_parts.push(content.to_string()); let _ = tx.send(Ok(content.to_string())); }
                    }
                    if let Some(tc_deltas) = delta["tool_calls"].as_array() {
                        for tc in tc_deltas {
                            let idx = tc["index"].as_u64().unwrap_or(0) as usize;
                            let builder = tool_call_map.entry(idx).or_insert_with(|| ToolCallBuilder { id: String::new(), name: String::new(), arguments: String::new() });
                            if let Some(id) = tc["id"].as_str() { if !id.is_empty() { builder.id = id.to_string(); } }
                            if let Some(name) = tc["function"]["name"].as_str() { builder.name.push_str(name); }
                            if let Some(args) = tc["function"]["arguments"].as_str() { builder.arguments.push_str(args); }
                        }
                    }
                }
            }

            // Build the final Message
            let content = if content_parts.is_empty() { None } else { Some(content_parts.concat()) };
            let tool_calls: Option<Vec<ToolCall>> = if tool_call_map.is_empty() {
                None
            } else {
                let mut calls: Vec<ToolCall> = tool_call_map.into_iter().map(|(_, b)| {
                    let args = serde_json::from_str(&b.arguments).unwrap_or(serde_json::Value::Object(Default::default()));
                    ToolCall { id: b.id, name: b.name, arguments: args }
                }).collect();
                calls.sort_by(|a, b| a.id.cmp(&b.id));
                Some(calls)
            };

            let msg = Message::new_assistant(content, tool_calls);
            if let Ok(mut last) = last_stream.lock() { *last = Some(msg); }
        });

        let stream = UnboundedReceiverStream::new(rx);
        Ok(Box::pin(stream))
    }

    /// Takes the message stored by the most recent streaming call, leaving
    /// the slot empty. Returns `None` if nothing is stored or the lock is
    /// poisoned.
    fn last_stream_message(&self) -> Option<Message> {
        match self.last_stream_message.lock() {
            Ok(mut slot) => slot.take(),
            Err(_) => None,
        }
    }

    /// Takes the most recently recorded token usage, leaving the slot empty.
    /// Returns `None` if nothing is stored or the lock is poisoned.
    fn last_usage(&self) -> Option<crate::messages::Usage> {
        match self.last_usage.lock() {
            Ok(mut slot) => slot.take(),
            Err(_) => None,
        }
    }

    /// Call the embeddings endpoint (OpenAI-compatible) for a single text.
    ///
    /// Posts to `<base_url>/embeddings` with the fixed model
    /// `text-embedding-3-small` and returns the first embedding vector.
    ///
    /// Returns `None` when the request fails, the status is not a success,
    /// the body is not the expected JSON shape, any component is not a
    /// number, or the vector is empty.
    async fn embed(&self, text: &str) -> Option<Vec<f32>> {
        // The `/` separator is required: without it the path fuses onto the
        // last segment of the base URL (e.g. ".../v1embeddings").
        let url = format!("{}/embeddings", self.base_url.trim_end_matches('/'));
        let body = serde_json::json!({
            "model": "text-embedding-3-small",
            "input": text
        });
        let resp = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await
            .ok()?;
        if !resp.status().is_success() {
            return None;
        }
        let data: serde_json::Value = resp.json().await.ok()?;
        let embedding = data["data"][0]["embedding"].as_array()?;
        // All-or-nothing: skipping a bad component would yield a vector of
        // the wrong dimension.
        let vec = embedding
            .iter()
            .map(|v| v.as_f64().map(|f| f as f32))
            .collect::<Option<Vec<f32>>>()?;
        if vec.is_empty() { None } else { Some(vec) }
    }
}

/// Accumulator for one tool call whose pieces arrive across several
/// streaming deltas.
#[derive(Debug, Default)]
struct ToolCallBuilder {
    /// Call identifier; overwritten whenever a delta carries a non-empty id.
    id: String,
    /// Function name, built by appending each name fragment.
    name: String,
    /// Raw argument text, built by appending each fragment; parsed as JSON
    /// only once the stream has finished.
    arguments: String,
}

/// Factory function: builds a boxed provider from a provider type name.
///
/// `provider_type` is matched case-insensitively against `openai` and
/// `openai-compatible`. When `base_url` is `None` the public OpenAI endpoint
/// is used. The provider is created with 3 attempts and a 120-second timeout.
///
/// # Errors
/// Returns a message naming the unrecognized `provider_type`.
pub fn create_provider(
    provider_type: &str,
    model: &str,
    api_key: &str,
    base_url: Option<&str>,
) -> Result<Box<dyn Provider>, String> {
    let kind = provider_type.to_lowercase();
    if kind != "openai" && kind != "openai-compatible" {
        return Err(format!("Unknown provider type: '{}'. Supported: openai.", provider_type));
    }

    let endpoint = match base_url {
        Some(custom) => custom,
        None => "https://api.openai.com/v1",
    };
    let provider = OpenAICompatibleProvider::new(
        model.to_owned(),
        api_key.to_owned(),
        endpoint.to_owned(),
        3,
        120,
    );
    Ok(Box::new(provider))
}