//! Shared client for OpenAI-compatible APIs (Groq, OpenAI, etc.).
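//!
//! A minimal usage sketch; the module path, model name, and `Message::user`
//! constructor are illustrative rather than part of this module:
//!
//! ```ignore
//! let client = OpenAICompatClient::new(OpenAICompatConfig {
//!     base_url: "https://api.openai.com/v1".into(),
//!     api_key: std::env::var("OPENAI_API_KEY").expect("API key"),
//!     model: "gpt-4o-mini".into(),
//!     temperature: 0.2,
//!     max_tokens: 512,
//!     top_p: 1.0,
//!     provider_name: "openai",
//! });
//!
//! // `Message::user` is hypothetical; build messages however `super::base`
//! // exposes them.
//! let response = client
//!     .chat_completion(vec![Message::user("Hello")], None, None, None)
//!     .await?;
//! ```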

use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::{debug, error};

use super::base::{LlmResponse, Message, ResponseFormat, Tool, ToolCall};
use crate::error::{NeomemxError, Result};
use crate::llm::utils::extract_json;

/// Configuration for OpenAI-compatible API clients.
#[derive(Debug, Clone)]
pub struct OpenAICompatConfig {
    /// API base URL without a trailing slash, e.g. `https://api.openai.com/v1`.
    pub base_url: String,
    /// API key for authentication.
    pub api_key: String,
    /// Model identifier.
    pub model: String,
    /// Sampling temperature (0.0-2.0).
    pub temperature: f32,
    /// Maximum tokens to generate.
    pub max_tokens: u32,
    /// Top-p (nucleus) sampling parameter.
    pub top_p: f32,
    /// Provider name for logging.
    pub provider_name: &'static str,
}

/// Client for OpenAI-compatible APIs (Groq, OpenAI, etc.).
pub struct OpenAICompatClient {
    config: OpenAICompatConfig,
    client: Client,
    skip_sampling_params: bool,
}

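/// Request body for the `/chat/completions` endpoint, mirroring the OpenAI
/// wire format. Optional fields are omitted from the JSON entirely rather
/// than serialized as `null`.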
#[derive(Debug, Serialize)]
pub(crate) struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
}

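/// Successful completion response; only the `choices` field is deserialized.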
#[derive(Debug, Deserialize)]
pub(crate) struct ChatCompletionResponse {
    pub choices: Vec<Choice>,
}

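/// A single completion choice.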
#[derive(Debug, Deserialize)]
pub(crate) struct Choice {
    pub message: ResponseMessage,
}

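/// Assistant message within a choice; may carry text content, tool calls, or both.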
#[derive(Debug, Deserialize)]
pub(crate) struct ResponseMessage {
    pub content: Option<Content>,
    #[serde(default)]
    pub tool_calls: Option<Vec<ApiToolCall>>,
}

/// OpenAI/Groq can return either a raw string or structured parts for `content`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum Content {
    Text(String),
    Parts(Vec<ContentPart>),
}

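/// A single part of structured `content`; only `text` parts are extracted.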
#[derive(Debug, Deserialize)]
pub(crate) struct ContentPart {
    #[serde(rename = "type")]
    pub kind: String,
    pub text: Option<String>,
}

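/// Tool call as returned by the API.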
#[derive(Debug, Deserialize)]
pub(crate) struct ApiToolCall {
    pub function: ApiFunction,
}

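/// Function name plus JSON-encoded arguments string.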
#[derive(Debug, Deserialize)]
pub(crate) struct ApiFunction {
    pub name: String,
    pub arguments: String,
}

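/// Standard OpenAI-style error envelope: `{"error": {...}}`.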
#[derive(Debug, Deserialize)]
pub(crate) struct ErrorResponse {
    pub error: ApiError,
}

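/// Error details within the envelope.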
#[derive(Debug, Deserialize)]
pub(crate) struct ApiError {
    pub message: String,
    #[serde(rename = "type")]
    #[allow(dead_code)]
    pub error_type: Option<String>,
}

impl OpenAICompatClient {
    /// Creates a new client with the given configuration.
    pub fn new(config: OpenAICompatConfig) -> Self {
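        // Connection pooling with keepalive; if the builder fails (it rarely
        // can), fall back to a default client instead of propagating the error.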
        let client = Client::builder()
            .pool_max_idle_per_host(8)
            .pool_idle_timeout(std::time::Duration::from_secs(90))
            .tcp_keepalive(std::time::Duration::from_secs(60))
            .no_proxy()
            .build()
            .unwrap_or_else(|_| Client::new());

        Self {
            config,
            client,
            skip_sampling_params: false,
        }
    }

    /// Configures whether to skip sampling parameters entirely, for reasoning
    /// models that reject explicit `temperature`, `top_p`, and `max_tokens`.
    pub fn with_skip_sampling_params(mut self, skip: bool) -> Self {
        self.skip_sampling_params = skip;
        self
    }

    /// Sends a chat completion request to the API.
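    ///
    /// # Errors
    ///
    /// Returns [`NeomemxError::LlmError`] if the provider responds with a
    /// non-success status, the body cannot be parsed, or the response contains
    /// no choices; transport errors from `reqwest` propagate via the crate's
    /// [`Result`].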
    pub async fn chat_completion(
        &self,
        messages: Vec<Message>,
        response_format: Option<ResponseFormat>,
        tools: Option<Vec<Tool>>,
        tool_choice: Option<String>,
    ) -> Result<LlmResponse> {
        // Reasoning models can reject explicit sampling parameters, so omit
        // them entirely when `skip_sampling_params` is set.
        let include_sampling = !self.skip_sampling_params;
        let request = ChatCompletionRequest {
            model: self.config.model.clone(),
            messages,
            temperature: include_sampling.then_some(self.config.temperature),
            max_tokens: include_sampling.then_some(self.config.max_tokens),
            top_p: include_sampling.then_some(self.config.top_p),
            response_format,
            tools,
            tool_choice,
        };

        let url = format!("{}/chat/completions", self.config.base_url);
        debug!("Sending request to {}: {}", self.config.provider_name, url);

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        let status = response.status();
        let body = response.text().await?;

        if !status.is_success() {
            // Fall back to the raw body when the error payload is not the
            // standard `{"error": {...}}` shape.
            let error: ErrorResponse =
                serde_json::from_str(&body).unwrap_or_else(|_| ErrorResponse {
                    error: ApiError {
                        message: body.clone(),
                        error_type: None,
                    },
                });
            error!(
                "{} API error: {}",
                self.config.provider_name, error.error.message
            );
            return Err(NeomemxError::LlmError(format!(
                "{}: {}",
                self.config.provider_name, error.error.message
            )));
        }

        let completion: ChatCompletionResponse = serde_json::from_str(&body).map_err(|e| {
            NeomemxError::LlmError(format!(
                "Failed to parse {} response: {}",
                self.config.provider_name, e
            ))
        })?;

        let choice = completion.choices.into_iter().next().ok_or_else(|| {
            NeomemxError::LlmError(format!(
                "No choices in {} response",
                self.config.provider_name
            ))
        })?;

        let content_text = choice
            .message
            .content
            .map(|c| match c {
                Content::Text(t) => t,
                Content::Parts(parts) => parts
                    .into_iter()
                    .filter_map(|p| p.text)
                    .collect::<Vec<String>>()
                    .join(""),
            })
            .unwrap_or_default();

        if let Some(api_tool_calls) = choice.message.tool_calls {
            if !api_tool_calls.is_empty() {
                // Tool calls with malformed argument JSON are skipped rather
                // than failing the whole response.
                let tool_calls: Vec<ToolCall> = api_tool_calls
                    .iter()
                    .filter_map(|tc| Self::parse_tool_call(tc).ok())
                    .collect();

                return Ok(LlmResponse::WithToolCalls {
                    content: Some(content_text).filter(|s| !s.is_empty()),
                    tool_calls,
                });
            }
        }

        Ok(LlmResponse::Text(content_text))
    }

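    /// Parses an API tool call into a [`ToolCall`], decoding the JSON-encoded
    /// `arguments` string into a map via `extract_json`.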
    fn parse_tool_call(api_call: &ApiToolCall) -> Result<ToolCall> {
        let arguments: HashMap<String, serde_json::Value> =
            extract_json(&api_call.function.arguments)?;

        Ok(ToolCall {
            name: api_call.function.name.clone(),
            arguments,
        })
    }
}
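
// Sketch of a sanity check for the untagged `Content` decoding: providers may
// send `content` as either a bare string or an array of typed parts, and both
// should deserialize.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn content_accepts_string_and_part_arrays() {
        let text: Content = serde_json::from_str(r#""hi""#).unwrap();
        assert!(matches!(text, Content::Text(t) if t == "hi"));

        let parts: Content =
            serde_json::from_str(r#"[{"type": "text", "text": "hi"}]"#).unwrap();
        assert!(matches!(parts, Content::Parts(p) if p.len() == 1));
    }
}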