plainllm 1.2.0

A plain & simple LLM client
Documentation
mod http;
use crate::Error;
use reqwest::Client;
use schemars::{schema_for, JsonSchema};
use serde::{Deserialize, Serialize};

use super::chat_completion::{ChatCompletionRequest, JsonSchemaFormat, ResponseFormat};
use super::options::LLMOptions;
use super::responses::{ResponseObject, ResponseRequest};
use super::tool_registry::ToolRegistry;
use super::types::{Endpoint, FunctionCall, LLMConfig, LLMResponse, Message, Method, Mode};
use super::utils::{
    extract_answer, finish_reason_is_tool_calls, handle_tool_calls, parse_chunks_to_llm_response,
    strip_thinking,
};
use serde_json::Value;
use std::sync::Arc;

/// A simplified LLM client that can handle single-shot calls, general calls, or structured calls.
pub struct PlainLLM {
    /// URL of the API server this client talks to (presumably a base URL that the HTTP
    /// layer combines with an `Endpoint` — confirm in `mod http`).
    pub api_url: String,
    /// Credential for the API; consumed by the HTTP layer, which is not visible here.
    pub token: String,
    /// Client configuration; its `mode` selects chat-completion vs. Responses requests.
    config: LLMConfig,
    /// Reusable HTTP client, created once per `PlainLLM`.
    http_client: Client,
}

impl PlainLLM {
    pub fn new(api_url: &str, token: &str) -> Self {
        tracing::info!("Creating PlainLLM client");
        tracing::debug!("api_url: {}", api_url);
        Self {
            token: token.to_owned(),
            api_url: api_url.to_owned(),
            http_client: Client::new(),
            config: LLMConfig::default(),
        }
    }

    pub fn new_with_config(api_url: &str, token: &str, config: LLMConfig) -> Self {
        tracing::info!("Creating PlainLLM client with config");
        tracing::debug!("api_url: {}", api_url);
        Self {
            token: token.to_owned(),
            api_url: api_url.to_owned(),
            http_client: Client::new(),
            config,
        }
    }

    //////////////////////////////////////////
    // 1) Single-turn Q&A returning text
    //////////////////////////////////////////
    pub async fn ask(
        &self,
        model: &str,
        user_content: &str,
        opts: &LLMOptions<'_>,
    ) -> Result<String, Error> {
        tracing::info!("ask");
        tracing::debug!("model: {}", model);
        tracing::trace!("user_content: {}", user_content);
        let messages = vec![Message::new("user", user_content)];
        let (llm_response, _new_messages) = self.call_llm(model, messages, opts).await?;
        let answer = extract_answer(llm_response);
        tracing::debug!("answer: {}", answer);
        Ok(answer)
    }

    //////////////////////////////////////////
    // 2) The general method
    //////////////////////////////////////////
    pub async fn call_llm(
        &self,
        model: &str,
        mut messages: Vec<Message>,
        opts: &LLMOptions<'_>,
    ) -> Result<(LLMResponse, Vec<Message>), Error> {
        tracing::info!("call_llm");
        tracing::debug!("model: {} streaming: {}", model, opts.streaming);
        tracing::trace!("messages: {:?}", messages);

        if self.config.mode == Mode::Responses {
            return self.call_responses(model, messages, opts).await;
        }
        let mut request = ChatCompletionRequest::new(model.to_string(), messages.clone());
        request.stream = opts.streaming;

        if let Some(t) = opts.temperature {
            request.temperature = Some(t);
        }

        if let Some(p) = opts.top_p {
            request.top_p = Some(p);
        }

        if let Some(max) = opts.max_tokens {
            request.max_tokens = Some(max);
        }

        if let Some(ref stop) = opts.stop {
            request.stop = Some(stop.clone());
        }

        if let Some(p) = opts.presence_penalty {
            request.presence_penalty = Some(p);
        }

        if let Some(p) = opts.frequency_penalty {
            request.frequency_penalty = Some(p);
        }

        if let Some(p) = opts.top_k {
            request.top_k = Some(p);
        }

        if let Some(p) = opts.repeat_penalty {
            request.repeat_penalty = Some(p);
        }

        if let Some(p) = &opts.context_overflow_policy {
            request.context_overflow_policy = Some(p.to_string());
        }

        if let Some(registry) = opts.tools {
            request.tools = Some(registry.to_api_tools());
        }

        // Non-structured normal chat
        if opts.streaming {
            // streaming path
            let (chunks, partial_content) = self.stream_llm(&request, &opts.event_handlers).await?;
            let mut final_response = parse_chunks_to_llm_response(chunks, partial_content)?;

            // Possibly handle tool calls
            if finish_reason_is_tool_calls(&final_response) {
                messages = handle_tool_calls(
                    &mut final_response,
                    messages,
                    opts.tools,
                    &opts.event_handlers,
                )
                .await?;
                // Re-invoke with updated messages
                final_response = Box::pin(self.call_llm(model, messages.clone(), opts))
                    .await?
                    .0;
            }
            tracing::debug!("final_response: {:?}", final_response);
            Ok((final_response, messages))
        } else {
            // non-streaming
            let text = self
                .http_call(Endpoint::ChatCompletion, Method::Post, Some(&request))
                .await?;
            let mut final_response: LLMResponse =
                serde_json::from_str(&text).map_err(Error::Json)?;

            // Possibly handle tool calls
            if finish_reason_is_tool_calls(&final_response) {
                messages = handle_tool_calls(
                    &mut final_response,
                    messages,
                    opts.tools,
                    &opts.event_handlers,
                )
                .await?;
                final_response = Box::pin(self.call_llm(model, messages.clone(), opts))
                    .await?
                    .0;
            }
            tracing::debug!("final_response: {:?}", final_response);
            Ok((final_response, messages))
        }
    }

    async fn call_responses(
        &self,
        model: &str,
        messages: Vec<Message>,
        opts: &LLMOptions<'_>,
    ) -> Result<(LLMResponse, Vec<Message>), Error> {
        if opts.streaming {
            return Err(Error::Message("Streaming not supported".into()));
        }
        let mut input_text = String::new();
        for m in &messages {
            if let Some(ref c) = m.content {
                input_text.push_str(&format!("{}: {}\n", m.role, c));
            }
        }
        let mut req = ResponseRequest::new(
            model.to_string(),
            Value::String(input_text.trim().to_string()),
        );
        if let Some(t) = opts.temperature {
            req.temperature = Some(t);
        }
        if let Some(p) = opts.top_p {
            req.top_p = Some(p);
        }
        if let Some(max) = opts.max_tokens {
            req.max_output_tokens = Some(max);
        }
        if let Some(reg) = opts.tools {
            req.tools = Some(reg.to_api_tools());
        }
        let instructions = messages
            .iter()
            .find(|m| m.role == "system")
            .and_then(|m| m.content.clone());
        req.instructions = instructions;
        let text = self
            .http_call(Endpoint::Responses, Method::Post, Some(&req))
            .await?;
        let resp: ResponseObject = serde_json::from_str(&text).map_err(Error::Json)?;
        Ok((resp.into(), messages))
    }

    //////////////////////////////////////////
    // 3) Structured output method
    //////////////////////////////////////////
    pub async fn call_llm_structured<T>(
        &self,
        model: &str,
        mut messages: Vec<Message>,
        opts: &LLMOptions<'_>,
    ) -> Result<T, Error>
    where
        T: for<'de> Deserialize<'de> + JsonSchema + Serialize + std::fmt::Debug,
    {
        tracing::info!("call_llm_structured");
        tracing::debug!("model: {} streaming: {}", model, opts.streaming);
        tracing::trace!("messages: {:?}", messages);

        if self.config.mode == Mode::Responses {
            let (resp, _msgs) = self.call_llm(model, messages, opts).await?;
            let answer_text = extract_answer(resp);
            let cleaned = strip_thinking(&answer_text);
            let structured: T = serde_json::from_str(&cleaned).map_err(Error::Json)?;
            tracing::debug!("structured: {:?}", structured);
            return Ok(structured);
        }
        // 1) Build a JSON schema for `T`
        let schema = schema_for!(T);
        let mut json_schema_value = serde_json::to_value(schema)?;
        if let serde_json::Value::Object(ref mut map) = json_schema_value {
            map.insert(
                "additionalProperties".to_string(),
                serde_json::Value::Bool(false),
            );
            if let Some(props) = map.get("properties").and_then(|v| v.as_object()) {
                let all_keys: Vec<serde_json::Value> = props
                    .keys()
                    .map(|k| serde_json::Value::String(k.clone()))
                    .collect();
                map.insert("required".to_string(), serde_json::Value::Array(all_keys));
            }
        }

        // 2) Create the "json_schema" response format
        let response_format = JsonSchemaFormat {
            name: "my_schema".to_string(),
            strict: true,
            schema: json_schema_value,
        };
        let format = ResponseFormat {
            r#type: "json_schema".to_string(),
            json_schema: response_format,
        };

        let mut request = ChatCompletionRequest::new(model.to_string(), messages.clone());
        request.from_llm_options(&opts);
        request.with_response_format(format);

        if let Some(registry) = opts.tools {
            request.tools = Some(registry.to_api_tools());
        }

        // 3) Execute, possibly streaming
        let final_response = if opts.streaming {
            let (chunks, partial_content) = self.stream_llm(&request, &opts.event_handlers).await?;
            let mut resp = parse_chunks_to_llm_response(chunks, partial_content)?;
            if finish_reason_is_tool_calls(&resp) {
                messages = handle_tool_calls(&mut resp, messages, opts.tools, &opts.event_handlers)
                    .await?;
                // Re-invoke ourselves after tool calls
                resp = self.call_llm(model, messages, opts).await?.0;
            }
            resp
        } else {
            // Non-streaming
            let text = self
                .http_call(Endpoint::ChatCompletion, Method::Post, Some(&request))
                .await?;
            let mut resp: LLMResponse = serde_json::from_str(&text).map_err(Error::Json)?;
            if finish_reason_is_tool_calls(&resp) {
                messages = handle_tool_calls(&mut resp, messages, opts.tools, &opts.event_handlers)
                    .await?;
                resp = self.call_llm(model, messages, opts).await?.0;
            }
            resp
        };

        // 4) Attempt to parse the final content into `T`
        let answer_text = extract_answer(final_response);
        let cleaned = strip_thinking(&answer_text);
        let structured: T = serde_json::from_str(&cleaned).map_err(Error::Json)?;
        tracing::debug!("structured: {:?}", structured);
        Ok(structured)
    }
}

impl PlainLLM {
    /// Convenience helper to call the LLM with a ToolRegistry.
    /// This mirrors the older `call_llm_with_tools` API used by the examples.
    pub async fn call_llm_with_tools(
        llm: &PlainLLM,
        model: &str,
        messages: Vec<Message>,
        registry: &ToolRegistry,
        on_call: Option<Arc<dyn Fn(&FunctionCall) + Send + Sync>>,
        on_result: Option<Arc<dyn Fn(&FunctionCall, &Result<Value, String>) + Send + Sync>>,
    ) -> Result<(LLMResponse, Vec<Message>), Error> {
        tracing::info!("call_llm_with_tools");
        let mut new_messages = messages;
        tracing::debug!("model: {}", model);
        tracing::trace!("messages: {:?}", new_messages);

        let mut request = ChatCompletionRequest::new(model.to_string(), new_messages.clone());
        request.tools = Some(registry.to_api_tools());

        let text = llm
            .http_call(Endpoint::ChatCompletion, Method::Post, Some(&request))
            .await?;
        let llm_response: LLMResponse = serde_json::from_str(&text).map_err(Error::Json)?;
        tracing::debug!("raw llm_response: {:?}", llm_response);

        if finish_reason_is_tool_calls(&llm_response) {
            let message = llm_response
                .choices
                .first()
                .and_then(|c| c.message.as_ref())
                .ok_or_else(|| Error::Message("No message".into()))?;

            new_messages.push(message.clone());

            let calls = message
                .tool_calls
                .as_ref()
                .ok_or_else(|| Error::Message("No tool calls".into()))?;

            for call in calls {
                tracing::debug!("tool call {}", call.function.name);
                let args_str = call.function.arguments.as_str().unwrap_or_default();
                tracing::trace!("tool args: {}", args_str);
                if let Some(ref cb) = on_call {
                    cb(call);
                }
                let args: Value = serde_json::from_str(args_str).unwrap_or_default();
                let res = registry.call(&call.function.name, args).await;
                tracing::debug!("tool result: {:?}", res);

                if let Some(ref cb) = on_result {
                    cb(call, &res);
                }

                let tool_call_message = Message {
                    role: "tool".to_string(),
                    content: Some(format!("{:?}", res)),
                    tool_calls: None,
                    tool_call_id: Some(call.id.clone()),
                };

                new_messages.push(tool_call_message);
            }

            Box::pin(Self::call_llm_with_tools(
                llm,
                model,
                new_messages,
                registry,
                on_call.clone(),
                on_result.clone(),
            ))
            .await
        } else {
            if let Some(msg) = llm_response.choices.get(0).and_then(|c| c.message.clone()) {
                new_messages.push(msg);
            }
            tracing::debug!("returning llm_response: {:?}", llm_response);
            Ok((llm_response, new_messages))
        }
    }
}