cortex-agent 0.2.1

Self-learning AI agent with persistent memory, tools, plugins, and a beautiful terminal UI
use serde::{Deserialize, Serialize};

/// A tool invocation requested by the LLM.
///
/// `arguments` holds the already-decoded JSON arguments of the call. If the
/// field is absent when deserializing, it defaults to an empty JSON object,
/// which is the same fallback `ToolCall::from_api_dict` uses. It therefore never
/// defaults to `null`, which would later be re-encoded as the string `"null"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Call identifier taken from the API payload.
    pub id: String,
    /// Name of the tool/function being invoked.
    pub name: String,
    /// Decoded JSON arguments for the call.
    #[serde(default = "empty_tool_arguments")]
    pub arguments: serde_json::Value,
}

/// Serde default for `ToolCall::arguments`: an empty JSON object (`{}`).
fn empty_tool_arguments() -> serde_json::Value {
    serde_json::Value::Object(Default::default())
}

impl ToolCall {
    /// Serialize to an OpenAI-style `tool_calls` entry.
    ///
    /// `function.arguments` is emitted as a JSON-encoded *string* (compact form)
    /// rather than as a nested object.
    pub fn to_api_dict(&self) -> serde_json::Value {
        serde_json::json!({
            "id": self.id,
            "type": "function",
            "function": {
                "name": self.name,
                // `Value`'s `Display` impl writes compact JSON and cannot fail,
                // so no error fallback is needed (same output as `to_string(&v)`).
                "arguments": self.arguments.to_string(),
            },
        })
    }

    /// Parse a `tool_calls` entry from an API response.
    ///
    /// Missing or non-string `id` / `function.name` become empty strings.
    /// `function.arguments` is accepted either as a JSON-encoded string (the
    /// usual form) or as an already-decoded JSON object. Some OpenAI-compatible
    /// backends may send the latter, and it used to be discarded. Any other
    /// shape falls back to an empty object: absent, `null`, an unparseable
    /// string, or a non-object, non-string value.
    pub fn from_api_dict(d: &serde_json::Value) -> Self {
        let func = &d["function"];
        let arguments = match &func["arguments"] {
            serde_json::Value::String(raw) => serde_json::from_str(raw).ok(),
            obj @ serde_json::Value::Object(_) => Some(obj.clone()),
            _ => None,
        }
        .unwrap_or_else(|| serde_json::Value::Object(Default::default()));

        Self {
            id: d["id"].as_str().unwrap_or("").to_string(),
            name: func["name"].as_str().unwrap_or("").to_string(),
            arguments,
        }
    }
}

/// A single message in the conversation.
///
/// Converted to and from the OpenAI chat JSON shape by `Message::to_api_dict`
/// and `Message::from_api_dict`.
#[derive(Debug, Clone)]
pub struct Message {
    /// Speaker role. The constructors use "system", "user", "assistant" and "tool".
    pub role: String,
    /// Text content. `None` is serialized as an explicit JSON `null`.
    pub content: Option<String>,
    /// Tool invocations attached to an assistant message (set by `new_assistant`).
    pub tool_calls: Option<Vec<ToolCall>>,
    /// For tool-result messages: the id of the tool call this result belongs to.
    pub tool_call_id: Option<String>,
    /// For tool-result messages: the name of the tool that produced the result.
    pub name: Option<String>,
}

impl Usage {
    /// Rough USD cost of the tokens counted in `prompt_tokens` and
    /// `completion_tokens`, priced by substring-matching `model`.
    ///
    /// DeepSeek models are handled first: names that also contain "flash" or
    /// "free" are priced at zero. All other models are matched against the
    /// tiers below in order, and the first hit wins. That is why `"gpt-4o"`
    /// must stay ahead of `"gpt-4"`. Models matching nothing use the fallback
    /// rate. Matching is case-sensitive.
    pub fn estimated_cost(&self, model: &str) -> f64 {
        // (any of these substrings, USD per prompt token, USD per completion token)
        const TIERS: &[(&[&str], f64, f64)] = &[
            (&["llama", "nemotron"], 0.10e-6, 0.20e-6),
            (&["gpt-4o"], 2.50e-6, 10.0e-6),
            (&["gpt-4", "claude"], 10.0e-6, 30.0e-6),
            (&["mixtral", "gemma"], 0.60e-6, 0.60e-6),
            (&["gemini"], 0.50e-6, 1.50e-6),
            (&["mistral"], 2.0e-6, 6.0e-6),
        ];
        const FALLBACK: (f64, f64) = (0.50e-6, 1.50e-6);

        let (prompt_price, completion_price) = if !model.contains("deepseek") {
            TIERS
                .iter()
                .find(|tier| tier.0.iter().any(|needle| model.contains(*needle)))
                .map_or(FALLBACK, |tier| (tier.1, tier.2))
        } else if model.contains("flash") || model.contains("free") {
            (0.0, 0.0) // free tier
        } else {
            (0.14e-6, 0.28e-6)
        };

        let prompt_cost = self.prompt_tokens as f64 * prompt_price;
        let completion_cost = self.completion_tokens as f64 * completion_price;
        prompt_cost + completion_cost
    }
}

/// Token usage tracked across turns with per-turn breakdown.
///
/// `Usage::record_turn` adds each turn into the cumulative counters and
/// overwrites the `turn_*` snapshot. `Usage::estimated_cost` prices
/// `prompt_tokens` / `completion_tokens`.
#[derive(Debug, Clone, Default)]
pub struct Usage {
    // Cumulative session totals
    /// Prompt tokens summed over every recorded turn.
    pub prompt_tokens: u64,
    /// Completion tokens summed over every recorded turn.
    pub completion_tokens: u64,
    /// Prompt + completion tokens summed over every recorded turn.
    pub total_tokens: u64,
    /// Copy of `prompt_tokens`, kept in sync by `record_turn`.
    pub session_prompt: u64,
    /// Copy of `completion_tokens`, kept in sync by `record_turn`.
    pub session_completion: u64,
    /// Copy of `total_tokens`, kept in sync by `record_turn`.
    pub session_total: u64,

    // Per-turn tracking (for the most recent turn)
    /// Prompt tokens of the most recently recorded turn.
    pub turn_prompt: u64,
    /// Completion tokens of the most recently recorded turn.
    pub turn_completion: u64,
    /// `turn_prompt + turn_completion` for the most recently recorded turn.
    pub turn_total: u64,
}

impl Usage {
    /// Record a turn's usage, updating both per-turn and cumulative totals.
    pub fn record_turn(&mut self, prompt: u64, completion: u64) {
        self.turn_prompt = prompt;
        self.turn_completion = completion;
        self.turn_total = prompt + completion;

        self.prompt_tokens += prompt;
        self.completion_tokens += completion;
        self.total_tokens += prompt + completion;
        self.session_prompt = self.prompt_tokens;
        self.session_completion = self.completion_tokens;
        self.session_total = self.total_tokens;
    }

    /// Get per-turn estimated cost for the current model.
    #[allow(dead_code)]
    pub fn turn_cost(&self, model: &str) -> f64 {
        let usage = Usage {
            prompt_tokens: self.turn_prompt,
            completion_tokens: self.turn_completion,
            total_tokens: self.turn_total,
            ..Default::default()
        };
        usage.estimated_cost(model)
    }

    /// Get total session estimated cost.
    pub fn session_cost(&self, model: &str) -> f64 {
        self.estimated_cost(model)
    }
}

impl Message {
    /// Build a `"system"` message carrying `content`.
    pub fn new_system(content: impl Into<String>) -> Self {
        Self {
            role: "system".into(),
            content: Some(content.into()),
            tool_calls: None,
            tool_call_id: None,
            name: None,
        }
    }

    /// Build a `"user"` message carrying `content`.
    pub fn new_user(content: impl Into<String>) -> Self {
        Self {
            role: "user".into(),
            content: Some(content.into()),
            tool_calls: None,
            tool_call_id: None,
            name: None,
        }
    }

    /// Build an `"assistant"` message. Either the text or the tool calls may be absent.
    pub fn new_assistant(content: Option<String>, tool_calls: Option<Vec<ToolCall>>) -> Self {
        Self {
            role: "assistant".into(),
            content,
            tool_calls,
            tool_call_id: None,
            name: None,
        }
    }

    /// Build a `"tool"` message holding the result `content` of tool `name`
    /// for the call identified by `tool_call_id`.
    pub fn new_tool(content: impl Into<String>, tool_call_id: impl Into<String>, name: impl Into<String>) -> Self {
        Self {
            role: "tool".into(),
            content: Some(content.into()),
            tool_calls: None,
            tool_call_id: Some(tool_call_id.into()),
            name: Some(name.into()),
        }
    }

    /// Serialize to OpenAI API format.
    ///
    /// `role` and `content` are always emitted; `content` becomes an explicit
    /// `null` when absent. `tool_calls`, `tool_call_id` and `name` are emitted
    /// only when set.
    pub fn to_api_dict(&self) -> serde_json::Value {
        let mut d = serde_json::json!({
            "role": self.role,
        });
        // OpenAI API requires explicit null for content when tool_calls are present
        d["content"] = self
            .content
            .clone()
            .map_or(serde_json::Value::Null, serde_json::Value::String);
        if let Some(ref calls) = self.tool_calls {
            d["tool_calls"] = calls.iter().map(ToolCall::to_api_dict).collect();
        }
        if let Some(ref id) = self.tool_call_id {
            d["tool_call_id"] = serde_json::Value::String(id.clone());
        }
        if let Some(ref name) = self.name {
            d["name"] = serde_json::Value::String(name.clone());
        }
        d
    }

    /// Deserialize from OpenAI API response.
    ///
    /// A missing or non-string `role` defaults to `"user"`. `content`,
    /// `tool_call_id` and `name` are kept only when they are JSON strings.
    /// `tool_calls` is `None` when absent, not an array, or empty.
    pub fn from_api_dict(d: &serde_json::Value) -> Self {
        let tool_calls = d["tool_calls"]
            .as_array()
            .filter(|arr| !arr.is_empty())
            .map(|arr| arr.iter().map(ToolCall::from_api_dict).collect());

        Self {
            role: d["role"].as_str().unwrap_or("user").to_string(),
            // `as_str` already yields `None` for JSON null and non-strings.
            content: d.get("content").and_then(serde_json::Value::as_str).map(String::from),
            tool_calls,
            tool_call_id: d.get("tool_call_id").and_then(|v| v.as_str().map(String::from)),
            name: d.get("name").and_then(|v| v.as_str().map(String::from)),
        }
    }

    /// Human-readable one-message dump: `[ROLE]` followed by the content,
    /// truncated to at most 2000 bytes, then one line per tool call.
    // NOTE(review): allowed as dead code, presumably a debugging aid with no
    // current caller. Confirm before removing the attribute.
    #[allow(dead_code)]
    pub fn pretty_print(&self) -> String {
        const PREVIEW_LIMIT: usize = 2000;

        let mut out = format!("[{}]", self.role.to_uppercase());
        if let Some(ref content) = self.content {
            if content.len() > PREVIEW_LIMIT {
                // Slicing a `str` inside a multi-byte UTF-8 character panics, so
                // back off to the last char boundary at or before the limit.
                // Index 0 is always a boundary, so the fallback is never hit.
                let cut = (0..=PREVIEW_LIMIT)
                    .rev()
                    .find(|&i| content.is_char_boundary(i))
                    .unwrap_or(0);
                out.push_str(&content[..cut]);
                out.push_str("\n... [truncated]");
            } else {
                out.push_str(content);
            }
        }
        if let Some(ref calls) = self.tool_calls {
            for tc in calls {
                out.push_str(&format!("\n  -> tool: {}({:?})", tc.name, tc.arguments));
            }
        }
        out
    }
}