mod http;
use crate::Error;
use reqwest::Client;
use schemars::{schema_for, JsonSchema};
use serde::{Deserialize, Serialize};
use super::chat_completion::{ChatCompletionRequest, JsonSchemaFormat, ResponseFormat};
use super::options::LLMOptions;
use super::responses::{ResponseObject, ResponseRequest};
use super::tool_registry::ToolRegistry;
use super::types::{Endpoint, FunctionCall, LLMConfig, LLMResponse, Message, Method, Mode};
use super::utils::{
extract_answer, finish_reason_is_tool_calls, handle_tool_calls, parse_chunks_to_llm_response,
strip_thinking,
};
use serde_json::Value;
use std::sync::Arc;
/// Client for an LLM HTTP API that talks either to a chat-completion endpoint or,
/// when `config.mode == Mode::Responses`, to a Responses endpoint.
///
/// Request transport (`http_call`, `stream_llm`) lives outside this file
/// (presumably in the `http` submodule — not visible here).
pub struct PlainLLM {
/// API token given to the constructor, stored verbatim.
/// NOTE(review): presumably sent as an auth header by the HTTP layer — confirm there.
pub token: String,
/// API base URL given to the constructor, stored verbatim.
pub api_url: String,
/// reqwest client used for requests; each constructor creates its own `Client::new()`.
http_client: Client,
/// Client configuration; `mode` selects chat completions vs. the Responses endpoint.
config: LLMConfig,
}
impl PlainLLM {
/// Creates a client for `api_url` using `token` and the default `LLMConfig`.
///
/// Both strings are copied into owned fields. A fresh `reqwest::Client` is created.
pub fn new(api_url: &str, token: &str) -> Self {
    tracing::info!("Creating PlainLLM client");
    tracing::debug!("api_url: {}", api_url);
    let token = String::from(token);
    let api_url = String::from(api_url);
    Self {
        token,
        api_url,
        http_client: Client::new(),
        config: LLMConfig::default(),
    }
}
/// Creates a client for `api_url` using `token` and an explicit `config`.
///
/// Same as `new`, except the caller chooses the configuration (for example `Mode`).
pub fn new_with_config(api_url: &str, token: &str, config: LLMConfig) -> Self {
    tracing::info!("Creating PlainLLM client with config");
    tracing::debug!("api_url: {}", api_url);
    let http_client = Client::new();
    Self {
        token: token.into(),
        api_url: api_url.into(),
        http_client,
        config,
    }
}
/// Sends a single `user` message to `model` and returns the answer text.
///
/// This is a thin convenience over `call_llm`, so tool calls, streaming and the
/// configured mode are all handled there. The conversation history it returns is
/// discarded, and the answer is pulled out of the response with `extract_answer`.
///
/// # Errors
/// Any error returned by `call_llm`.
pub async fn ask(
    &self,
    model: &str,
    user_content: &str,
    opts: &LLMOptions<'_>,
) -> Result<String, Error> {
    tracing::info!("ask");
    tracing::debug!("model: {}", model);
    tracing::trace!("user_content: {}", user_content);
    let answer = self
        .call_llm(model, vec![Message::new("user", user_content)], opts)
        .await
        .map(|(response, _history)| extract_answer(response))?;
    tracing::debug!("answer: {}", answer);
    Ok(answer)
}
pub async fn call_llm(
&self,
model: &str,
mut messages: Vec<Message>,
opts: &LLMOptions<'_>,
) -> Result<(LLMResponse, Vec<Message>), Error> {
tracing::info!("call_llm");
tracing::debug!("model: {} streaming: {}", model, opts.streaming);
tracing::trace!("messages: {:?}", messages);
if self.config.mode == Mode::Responses {
return self.call_responses(model, messages, opts).await;
}
let mut request = ChatCompletionRequest::new(model.to_string(), messages.clone());
request.stream = opts.streaming;
if let Some(t) = opts.temperature {
request.temperature = Some(t);
}
if let Some(p) = opts.top_p {
request.top_p = Some(p);
}
if let Some(max) = opts.max_tokens {
request.max_tokens = Some(max);
}
if let Some(ref stop) = opts.stop {
request.stop = Some(stop.clone());
}
if let Some(p) = opts.presence_penalty {
request.presence_penalty = Some(p);
}
if let Some(p) = opts.frequency_penalty {
request.frequency_penalty = Some(p);
}
if let Some(p) = opts.top_k {
request.top_k = Some(p);
}
if let Some(p) = opts.repeat_penalty {
request.repeat_penalty = Some(p);
}
if let Some(p) = &opts.context_overflow_policy {
request.context_overflow_policy = Some(p.to_string());
}
if let Some(registry) = opts.tools {
request.tools = Some(registry.to_api_tools());
}
if opts.streaming {
let (chunks, partial_content) = self.stream_llm(&request, &opts.event_handlers).await?;
let mut final_response = parse_chunks_to_llm_response(chunks, partial_content)?;
if finish_reason_is_tool_calls(&final_response) {
messages = handle_tool_calls(
&mut final_response,
messages,
opts.tools,
&opts.event_handlers,
)
.await?;
final_response = Box::pin(self.call_llm(model, messages.clone(), opts))
.await?
.0;
}
tracing::debug!("final_response: {:?}", final_response);
Ok((final_response, messages))
} else {
let text = self
.http_call(Endpoint::ChatCompletion, Method::Post, Some(&request))
.await?;
let mut final_response: LLMResponse =
serde_json::from_str(&text).map_err(Error::Json)?;
if finish_reason_is_tool_calls(&final_response) {
messages = handle_tool_calls(
&mut final_response,
messages,
opts.tools,
&opts.event_handlers,
)
.await?;
final_response = Box::pin(self.call_llm(model, messages.clone(), opts))
.await?
.0;
}
tracing::debug!("final_response: {:?}", final_response);
Ok((final_response, messages))
}
}
async fn call_responses(
&self,
model: &str,
messages: Vec<Message>,
opts: &LLMOptions<'_>,
) -> Result<(LLMResponse, Vec<Message>), Error> {
if opts.streaming {
return Err(Error::Message("Streaming not supported".into()));
}
let mut input_text = String::new();
for m in &messages {
if let Some(ref c) = m.content {
input_text.push_str(&format!("{}: {}\n", m.role, c));
}
}
let mut req = ResponseRequest::new(
model.to_string(),
Value::String(input_text.trim().to_string()),
);
if let Some(t) = opts.temperature {
req.temperature = Some(t);
}
if let Some(p) = opts.top_p {
req.top_p = Some(p);
}
if let Some(max) = opts.max_tokens {
req.max_output_tokens = Some(max);
}
if let Some(reg) = opts.tools {
req.tools = Some(reg.to_api_tools());
}
let instructions = messages
.iter()
.find(|m| m.role == "system")
.and_then(|m| m.content.clone());
req.instructions = instructions;
let text = self
.http_call(Endpoint::Responses, Method::Post, Some(&req))
.await?;
let resp: ResponseObject = serde_json::from_str(&text).map_err(Error::Json)?;
Ok((resp.into(), messages))
}
pub async fn call_llm_structured<T>(
&self,
model: &str,
mut messages: Vec<Message>,
opts: &LLMOptions<'_>,
) -> Result<T, Error>
where
T: for<'de> Deserialize<'de> + JsonSchema + Serialize + std::fmt::Debug,
{
tracing::info!("call_llm_structured");
tracing::debug!("model: {} streaming: {}", model, opts.streaming);
tracing::trace!("messages: {:?}", messages);
if self.config.mode == Mode::Responses {
let (resp, _msgs) = self.call_llm(model, messages, opts).await?;
let answer_text = extract_answer(resp);
let cleaned = strip_thinking(&answer_text);
let structured: T = serde_json::from_str(&cleaned).map_err(Error::Json)?;
tracing::debug!("structured: {:?}", structured);
return Ok(structured);
}
let schema = schema_for!(T);
let mut json_schema_value = serde_json::to_value(schema)?;
if let serde_json::Value::Object(ref mut map) = json_schema_value {
map.insert(
"additionalProperties".to_string(),
serde_json::Value::Bool(false),
);
if let Some(props) = map.get("properties").and_then(|v| v.as_object()) {
let all_keys: Vec<serde_json::Value> = props
.keys()
.map(|k| serde_json::Value::String(k.clone()))
.collect();
map.insert("required".to_string(), serde_json::Value::Array(all_keys));
}
}
let response_format = JsonSchemaFormat {
name: "my_schema".to_string(),
strict: true,
schema: json_schema_value,
};
let format = ResponseFormat {
r#type: "json_schema".to_string(),
json_schema: response_format,
};
let mut request = ChatCompletionRequest::new(model.to_string(), messages.clone());
request.from_llm_options(&opts);
request.with_response_format(format);
if let Some(registry) = opts.tools {
request.tools = Some(registry.to_api_tools());
}
let final_response = if opts.streaming {
let (chunks, partial_content) = self.stream_llm(&request, &opts.event_handlers).await?;
let mut resp = parse_chunks_to_llm_response(chunks, partial_content)?;
if finish_reason_is_tool_calls(&resp) {
messages = handle_tool_calls(&mut resp, messages, opts.tools, &opts.event_handlers)
.await?;
resp = self.call_llm(model, messages, opts).await?.0;
}
resp
} else {
let text = self
.http_call(Endpoint::ChatCompletion, Method::Post, Some(&request))
.await?;
let mut resp: LLMResponse = serde_json::from_str(&text).map_err(Error::Json)?;
if finish_reason_is_tool_calls(&resp) {
messages = handle_tool_calls(&mut resp, messages, opts.tools, &opts.event_handlers)
.await?;
resp = self.call_llm(model, messages, opts).await?.0;
}
resp
};
let answer_text = extract_answer(final_response);
let cleaned = strip_thinking(&answer_text);
let structured: T = serde_json::from_str(&cleaned).map_err(Error::Json)?;
tracing::debug!("structured: {:?}", structured);
Ok(structured)
}
}
impl PlainLLM {
pub async fn call_llm_with_tools(
llm: &PlainLLM,
model: &str,
messages: Vec<Message>,
registry: &ToolRegistry,
on_call: Option<Arc<dyn Fn(&FunctionCall) + Send + Sync>>,
on_result: Option<Arc<dyn Fn(&FunctionCall, &Result<Value, String>) + Send + Sync>>,
) -> Result<(LLMResponse, Vec<Message>), Error> {
tracing::info!("call_llm_with_tools");
let mut new_messages = messages;
tracing::debug!("model: {}", model);
tracing::trace!("messages: {:?}", new_messages);
let mut request = ChatCompletionRequest::new(model.to_string(), new_messages.clone());
request.tools = Some(registry.to_api_tools());
let text = llm
.http_call(Endpoint::ChatCompletion, Method::Post, Some(&request))
.await?;
let llm_response: LLMResponse = serde_json::from_str(&text).map_err(Error::Json)?;
tracing::debug!("raw llm_response: {:?}", llm_response);
if finish_reason_is_tool_calls(&llm_response) {
let message = llm_response
.choices
.first()
.and_then(|c| c.message.as_ref())
.ok_or_else(|| Error::Message("No message".into()))?;
new_messages.push(message.clone());
let calls = message
.tool_calls
.as_ref()
.ok_or_else(|| Error::Message("No tool calls".into()))?;
for call in calls {
tracing::debug!("tool call {}", call.function.name);
let args_str = call.function.arguments.as_str().unwrap_or_default();
tracing::trace!("tool args: {}", args_str);
if let Some(ref cb) = on_call {
cb(call);
}
let args: Value = serde_json::from_str(args_str).unwrap_or_default();
let res = registry.call(&call.function.name, args).await;
tracing::debug!("tool result: {:?}", res);
if let Some(ref cb) = on_result {
cb(call, &res);
}
let tool_call_message = Message {
role: "tool".to_string(),
content: Some(format!("{:?}", res)),
tool_calls: None,
tool_call_id: Some(call.id.clone()),
};
new_messages.push(tool_call_message);
}
Box::pin(Self::call_llm_with_tools(
llm,
model,
new_messages,
registry,
on_call.clone(),
on_result.clone(),
))
.await
} else {
if let Some(msg) = llm_response.choices.get(0).and_then(|c| c.message.clone()) {
new_messages.push(msg);
}
tracing::debug!("returning llm_response: {:?}", llm_response);
Ok((llm_response, new_messages))
}
}
}