meme 0.7.0

Long-term memory for AI agents.
Documentation
//! OpenAI-compatible async LLM client.

use std::time::Duration;

use serde::{Deserialize, Serialize};

use crate::config::LlmConfig;
use crate::error::{MemeError, Result};
use crate::llm::json::extract_json_from_text;

/// Author of a chat message.
///
/// Each variant is (de)serialized as its lowercase name: `"system"`, `"user"`,
/// or `"assistant"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub(crate) enum Role {
    /// System prompt (wire name `"system"`).
    #[serde(rename = "system")]
    System,
    /// User message (wire name `"user"`).
    #[serde(rename = "user")]
    User,
    /// Assistant response (wire name `"assistant"`).
    #[serde(rename = "assistant")]
    Assistant,
}

/// One turn of a chat conversation.
///
/// Serializes as an object with `role` and `content` fields, in that order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct Message {
    /// Who authored this turn.
    pub role: Role,
    /// Text of the turn.
    pub content: String,
}

impl Message {
    /// Shared constructor behind the role-specific helpers.
    fn with_role(role: Role, content: impl Into<String>) -> Self {
        let content = content.into();
        Self { role, content }
    }

    /// Build a message authored by [`Role::System`].
    #[must_use]
    pub(crate) fn system(content: impl Into<String>) -> Self {
        Self::with_role(Role::System, content)
    }

    /// Build a message authored by [`Role::User`].
    #[must_use]
    pub(crate) fn user(content: impl Into<String>) -> Self {
        Self::with_role(Role::User, content)
    }
}

/// Per-request settings for a chat completion.
#[derive(Debug, Clone, Copy)]
pub(crate) struct ChatOptions {
    /// Sampling temperature sent with the request.
    pub temperature: f32,
    /// When `true`, the request asks the API for JSON-object output.
    pub json_mode: bool,
}

impl Default for ChatOptions {
    /// Temperature `0.1` with JSON mode enabled.
    fn default() -> Self {
        let temperature = 0.1;
        let json_mode = true;
        Self {
            temperature,
            json_mode,
        }
    }
}

/// OpenAI-compatible HTTP LLM client.
///
/// `Debug` is implemented by hand so the bearer token is never written to
/// logs, panic messages, or `dbg!` output.
#[derive(Clone)]
pub(crate) struct LlmClient {
    /// Shared HTTP client.
    http: reqwest::Client,
    /// API base URL, stored without trailing `/` characters.
    base_url: String,
    /// Bearer token. Redacted in `Debug` output.
    api_key: String,
    /// Model identifier.
    model: String,
    /// Number of request attempts `chat_structured` makes before giving up.
    /// A value of `0` is treated as `1`, so at least one request is always sent.
    max_retries: u32,
}

impl std::fmt::Debug for LlmClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A derived impl would print `api_key` verbatim; redact it instead.
        f.debug_struct("LlmClient")
            .field("http", &self.http)
            .field("base_url", &self.base_url)
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("max_retries", &self.max_retries)
            .finish()
    }
}

impl LlmClient {
    /// Upper bound, in seconds, on the backoff delay between attempts.
    const MAX_BACKOFF_SECS: u64 = 30;

    /// Create a new client from configuration using a shared HTTP client.
    ///
    /// Trailing `/` characters are stripped from `config.base_url`, so request
    /// URLs can be built by appending a path.
    ///
    /// # Errors
    ///
    /// Returns [`MemeError::Config`] if `config.api_key` is `None`.
    pub(crate) fn new(http: reqwest::Client, config: &LlmConfig) -> Result<Self> {
        let api_key = config
            .api_key
            .clone()
            .ok_or_else(|| MemeError::Config("LLM API key is required".to_owned()))?;

        Ok(Self {
            http,
            base_url: config.base_url.trim_end_matches('/').to_owned(),
            api_key,
            model: config.model.clone(),
            max_retries: config.max_retries,
        })
    }

    /// Send a chat completion and deserialize the response into `T`.
    ///
    /// The message content is first parsed directly with `serde_json`. If that
    /// fails, the value returned by `extract_json_from_text` is deserialized
    /// instead. Retryable API errors and unparseable responses are retried with
    /// exponential backoff (1s, 2s, 4s, … capped at 30s). Non-retryable API
    /// errors are returned immediately.
    ///
    /// # Errors
    ///
    /// Returns the first non-retryable API error. Otherwise, once all attempts
    /// are used, returns the error from the last attempt
    /// ([`MemeError::JsonParse`] for a response that could not be parsed).
    #[tracing::instrument(skip(self, messages, opts), fields(model = %self.model))]
    pub(crate) async fn chat_structured<T: serde::de::DeserializeOwned>(
        &self,
        messages: &[Message],
        opts: &ChatOptions,
    ) -> Result<T> {
        // With `max_retries == 0` the range `0..0` would skip the loop and fail
        // without ever contacting the API; always make at least one attempt.
        let attempts = self.max_retries.max(1);
        let mut last_err = None;
        for attempt in 0..attempts {
            match self.call_api(messages, opts).await {
                Ok(content) => match serde_json::from_str::<T>(&content) {
                    Ok(parsed) => return Ok(parsed),
                    Err(e) => {
                        tracing::warn!(attempt = attempt + 1, error = %e, "JSON parse failed");
                        // Try to salvage the content via `extract_json_from_text`
                        // (presumably locates JSON embedded in surrounding text)
                        // before spending another request.
                        if let Ok(v) = extract_json_from_text(&content)
                            && let Ok(parsed) = serde_json::from_value::<T>(v)
                        {
                            return Ok(parsed);
                        }
                        last_err = Some(MemeError::JsonParse(e.to_string()));
                    }
                },
                Err(e) => {
                    if !e.is_retryable() {
                        return Err(e);
                    }
                    tracing::warn!(attempt = attempt + 1, error = %e, "LLM API call failed");
                    last_err = Some(e);
                }
            }
            // No sleep after the final attempt.
            if attempt + 1 < attempts {
                let wait = 2u64.saturating_pow(attempt).min(Self::MAX_BACKOFF_SECS);
                tokio::time::sleep(Duration::from_secs(wait)).await;
            }
        }
        Err(last_err.unwrap_or_else(|| MemeError::llm("all retries exhausted")))
    }

    /// Execute a single chat completion API request and return the content
    /// string of the first choice's message.
    ///
    /// When `opts.json_mode` is set, `"response_format": {"type": "json_object"}`
    /// is added to the request body.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the request cannot be sent;
    /// - the server answers with a non-2xx status (the error carries the status
    ///   code and response body);
    /// - the response body is not JSON;
    /// - `choices[0].message.content` is missing or is not a string.
    async fn call_api(&self, messages: &[Message], opts: &ChatOptions) -> Result<String> {
        let url = format!("{}/chat/completions", self.base_url);

        let mut body = serde_json::json!({
            "model": self.model,
            "messages": messages,
            "temperature": opts.temperature,
            "stream": false,
        });

        if opts.json_mode
            && let Some(obj) = body.as_object_mut()
        {
            obj.insert(
                "response_format".to_owned(),
                serde_json::json!({"type": "json_object"}),
            );
        }

        let resp = self
            .http
            .post(&url)
            // Sends the same `Authorization: Bearer …` header as before, but
            // reqwest marks the value as sensitive.
            .bearer_auth(&self.api_key)
            .json(&body)
            .send()
            .await?;

        let status = resp.status();
        if !status.is_success() {
            // Best effort: an unreadable error body becomes an empty string.
            let text = resp.text().await.unwrap_or_default();
            return Err(MemeError::llm_with_status(
                status.as_u16(),
                format!("API returned {status}: {text}"),
            ));
        }

        let data: serde_json::Value = resp.json().await?;
        data.get("choices")
            .and_then(|c| c.get(0))
            .and_then(|c| c.get("message"))
            .and_then(|m| m.get("content"))
            .and_then(serde_json::Value::as_str)
            .map(String::from)
            .ok_or_else(|| MemeError::llm("missing content in API response"))
    }
}