use std::pin::Pin;
use async_trait::async_trait;
use futures::Stream;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::error::Result;
use crate::message::Message;
use crate::stream::{StopReason, StreamChunk};
use crate::tool::ToolDefinition;
use crate::usage::Usage;
/// Reasoning-effort level carried by [`ChatRequest`] in its `reasoning_effort` field.
///
/// Serialized as a lowercase string: `"none"`, `"minimal"`, `"low"`, `"medium"`,
/// `"high"` or `"xhigh"`. The default is [`ReasoningEffort::Medium`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningEffort {
    /// Wire name `"none"`.
    None,
    /// Wire name `"minimal"`.
    Minimal,
    /// Wire name `"low"`.
    Low,
    /// Wire name `"medium"`; the default level.
    #[default]
    Medium,
    /// Wire name `"high"`.
    High,
    /// Wire name `"xhigh"` (spelled out explicitly rather than relying on `rename_all`).
    #[serde(rename = "xhigh")]
    XHigh,
}
impl ReasoningEffort {
    /// Returns the lowercase wire name of this level, e.g. `"medium"` or `"xhigh"`.
    ///
    /// The strings match what serde produces for each variant.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match *self {
            Self::XHigh => "xhigh",
            Self::High => "high",
            Self::Medium => "medium",
            Self::Low => "low",
            Self::Minimal => "minimal",
            Self::None => "none",
        }
    }
}
/// A chat-completion request body, normally assembled with the builder methods
/// on `ChatRequest`.
///
/// Every `Option` field is left out of the serialized JSON while it is `None`.
/// When deserializing, a missing `model`, `messages` or `stream` falls back to
/// its default (empty string, empty list, `false`).
#[derive(Debug, Clone, Default)]
#[derive(Serialize, Deserialize)]
pub struct ChatRequest {
    /// Model identifier.
    #[serde(default)]
    pub model: String,
    /// Messages sent with the request.
    #[serde(default)]
    pub messages: Vec<Message>,
    /// Completion token cap; on input the key `max_tokens` is accepted as well.
    #[serde(skip_serializing_if = "Option::is_none", alias = "max_tokens")]
    pub max_completion_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ToolDefinition>>,
    /// Tool-choice directive kept as raw JSON (as produced by `ToolChoice::to_value`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Streaming flag; unlike the optional fields it is always serialized.
    #[serde(default)]
    pub stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    /// End-user identifier (set via `user_id`, not `user`, on the builder).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub service_tier: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub store: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<std::collections::HashMap<String, String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<ReasoningEffort>,
}
impl ChatRequest {
    /// Creates a request for `model` with no messages; every other field takes
    /// its `Default` value.
    #[must_use]
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            ..Default::default()
        }
    }

    /// Creates a request for `model` that starts out with `messages`.
    #[must_use]
    pub fn with_messages(model: impl Into<String>, messages: Vec<Message>) -> Self {
        Self {
            model: model.into(),
            messages,
            ..Default::default()
        }
    }

    /// Appends a system message built from `content`.
    #[must_use]
    pub fn system(mut self, content: impl Into<String>) -> Self {
        self.messages.push(Message::system(content));
        self
    }

    /// Appends a user message built from `content`.
    ///
    /// To set the end-user identifier field instead, use [`Self::user_id`].
    #[must_use]
    pub fn user(mut self, content: impl Into<String>) -> Self {
        self.messages.push(Message::user(content));
        self
    }

    /// Appends an assistant message built from `content`.
    #[must_use]
    pub fn assistant(mut self, content: impl Into<String>) -> Self {
        self.messages.push(Message::assistant(content));
        self
    }

    /// Appends an already-constructed `message`.
    #[must_use]
    pub fn message(mut self, message: Message) -> Self {
        self.messages.push(message);
        self
    }

    /// Replaces the entire message list with `messages` (does not append).
    #[must_use]
    pub fn messages(mut self, messages: Vec<Message>) -> Self {
        self.messages = messages;
        self
    }

    /// Sets `max_completion_tokens`.
    #[must_use]
    pub const fn max_completion_tokens(mut self, tokens: u32) -> Self {
        self.max_completion_tokens = Some(tokens);
        self
    }

    /// Sets `temperature`.
    #[must_use]
    pub const fn temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets `top_p`.
    #[must_use]
    pub const fn top_p(mut self, top_p: f32) -> Self {
        self.top_p = Some(top_p);
        self
    }

    /// Sets `n`.
    #[must_use]
    pub const fn n(mut self, n: u32) -> Self {
        self.n = Some(n);
        self
    }

    /// Sets the stop sequences, replacing any previous ones.
    #[must_use]
    pub fn stop(mut self, stop: Vec<String>) -> Self {
        self.stop = Some(stop);
        self
    }

    /// Sets the tool definitions, replacing any previous ones.
    #[must_use]
    pub fn tools(mut self, tools: Vec<ToolDefinition>) -> Self {
        self.tools = Some(tools);
        self
    }

    /// Sets `tool_choice` from a [`ToolChoice`] or a `&str`, storing its JSON form.
    ///
    /// For a `&str`, `"auto"`, `"required"` and `"none"` select those modes, and
    /// any other string is taken as a function name.
    #[must_use]
    pub fn tool_choice(mut self, choice: impl Into<ToolChoice>) -> Self {
        self.tool_choice = Some(choice.into().to_value());
        self
    }

    /// Sets `parallel_tool_calls`.
    #[must_use]
    pub const fn parallel_tool_calls(mut self, enabled: bool) -> Self {
        self.parallel_tool_calls = Some(enabled);
        self
    }

    /// Sets `stream` to `true`.
    #[must_use]
    pub const fn stream(mut self) -> Self {
        self.stream = true;
        self
    }

    /// Sets `response_format`.
    #[must_use]
    pub fn response_format(mut self, format: ResponseFormat) -> Self {
        self.response_format = Some(format);
        self
    }

    /// Shorthand for `response_format(ResponseFormat::from_type::<T>())`.
    ///
    /// Only available with the `schema` feature.
    #[cfg(feature = "schema")]
    #[must_use]
    pub fn output_type<T: schemars::JsonSchema>(self) -> Self {
        self.response_format(ResponseFormat::from_type::<T>())
    }

    /// Sets `seed`.
    #[must_use]
    pub const fn seed(mut self, seed: i64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Sets the `user` identifier field (this does not add a message; see [`Self::user`]).
    #[must_use]
    pub fn user_id(mut self, user: impl Into<String>) -> Self {
        self.user = Some(user.into());
        self
    }

    /// Sets `frequency_penalty`.
    #[must_use]
    pub const fn frequency_penalty(mut self, penalty: f32) -> Self {
        self.frequency_penalty = Some(penalty);
        self
    }

    /// Sets `presence_penalty`.
    #[must_use]
    pub const fn presence_penalty(mut self, penalty: f32) -> Self {
        self.presence_penalty = Some(penalty);
        self
    }

    /// Sets `logprobs`.
    #[must_use]
    pub const fn logprobs(mut self, enabled: bool) -> Self {
        self.logprobs = Some(enabled);
        self
    }

    /// Sets `top_logprobs`.
    ///
    /// NOTE(review): providers typically honor this only when
    /// [`logprobs(true)`](Self::logprobs) is also set. This setter does not change
    /// `logprobs`; confirm the requirement for the target backend.
    #[must_use]
    pub const fn top_logprobs(mut self, count: u32) -> Self {
        self.top_logprobs = Some(count);
        self
    }

    /// Sets `service_tier`.
    #[must_use]
    pub fn service_tier(mut self, tier: impl Into<String>) -> Self {
        self.service_tier = Some(tier.into());
        self
    }

    /// Sets `store`.
    #[must_use]
    pub const fn store(mut self, enabled: bool) -> Self {
        self.store = Some(enabled);
        self
    }

    /// Sets the `metadata` map, replacing any previous one.
    #[must_use]
    pub fn metadata(mut self, metadata: std::collections::HashMap<String, String>) -> Self {
        self.metadata = Some(metadata);
        self
    }

    /// Sets `reasoning_effort`.
    #[must_use]
    pub const fn reasoning_effort(mut self, effort: ReasoningEffort) -> Self {
        self.reasoning_effort = Some(effort);
        self
    }
}
/// Which tools, if any, a request allows the model to call.
///
/// Converted to JSON with `ToolChoice::to_value` and accepted by
/// `ChatRequest::tool_choice`. `PartialEq`/`Eq`/`Hash` are derived so callers can
/// compare choices and use them as keys, as with the other value types in this module.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub enum ToolChoice {
    /// Serialized as `"auto"`; the default.
    #[default]
    Auto,
    /// Serialized as `"required"`.
    Required,
    /// Serialized as `"none"`.
    None,
    /// Names a specific function; serialized as
    /// `{"type": "function", "function": {"name": <name>}}`.
    Function(String),
}
impl ToolChoice {
    /// Renders the choice as the JSON value stored in `ChatRequest::tool_choice`.
    ///
    /// The three modes become the bare strings `"auto"`, `"required"` and `"none"`.
    /// A named function becomes `{"type": "function", "function": {"name": ...}}`.
    #[must_use]
    pub fn to_value(&self) -> Value {
        let mode = match self {
            Self::Function(name) => {
                return serde_json::json!({
                    "type": "function",
                    "function": {"name": name}
                });
            }
            Self::Auto => "auto",
            Self::Required => "required",
            Self::None => "none",
        };
        Value::String(mode.to_owned())
    }
}
impl From<&str> for ToolChoice {
fn from(s: &str) -> Self {
match s {
"auto" => Self::Auto,
"required" => Self::Required,
"none" => Self::None,
name => Self::Function(name.to_owned()),
}
}
}
/// Output format requested from the model.
///
/// Serialized with an internal `"type"` tag in snake_case: `{"type": "text"}`,
/// `{"type": "json_object"}`, or `{"type": "json_schema", "json_schema": {...}}`.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseFormat {
    /// Tag `"text"`.
    Text,
    /// Tag `"json_object"`.
    JsonObject,
    /// Tag `"json_schema"`; carries the schema specification.
    JsonSchema { json_schema: JsonSchemaSpec },
}
impl ResponseFormat {
    /// JSON-object mode ([`ResponseFormat::JsonObject`]).
    #[must_use]
    pub const fn json() -> Self {
        Self::JsonObject
    }

    /// Schema-constrained mode using `schema` under the given `name`.
    ///
    /// `strict` is always set to `Some(true)`.
    #[must_use]
    pub fn json_schema(name: impl Into<String>, schema: Value) -> Self {
        let spec = JsonSchemaSpec {
            name: name.into(),
            schema,
            strict: Some(true),
        };
        Self::JsonSchema { json_schema: spec }
    }

    /// Schema-constrained mode whose schema and name are derived from `T`.
    ///
    /// Only available with the `schema` feature.
    #[cfg(feature = "schema")]
    #[must_use]
    pub fn from_type<T: schemars::JsonSchema>() -> Self {
        let (schema_name, schema) = generate_json_schema::<T>();
        Self::json_schema(schema_name, schema)
    }
}
/// Builds the JSON Schema for `T` and returns `(schema_name, schema)`.
///
/// The top-level `"$schema"` key is stripped from the schema object. If
/// serialization to a `Value` fails, the schema silently falls back to
/// `Value::Null`. NOTE(review): that fallback hides errors; it is presumably
/// unreachable for schemars output, so confirm.
#[cfg(feature = "schema")]
#[must_use]
pub fn generate_json_schema<T: schemars::JsonSchema>() -> (String, Value) {
    let mut schema = serde_json::to_value(schemars::schema_for!(T)).unwrap_or_default();
    if let Some(object) = schema.as_object_mut() {
        object.remove("$schema");
    }
    let schema_name = <T as schemars::JsonSchema>::schema_name().into_owned();
    (schema_name, schema)
}
/// Payload of [`ResponseFormat::JsonSchema`]: a named JSON Schema plus an
/// optional `strict` flag.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct JsonSchemaSpec {
    /// Name sent alongside the schema.
    pub name: String,
    /// The JSON Schema document.
    pub schema: Value,
    /// Omitted from the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
/// The result of [`ChatProvider::chat`]: the produced message, why generation
/// stopped, and optional metadata.
///
/// `Option` fields other than `raw` are omitted from the serialized JSON when `None`.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct ChatResponse {
    /// The message returned by the provider.
    pub message: Message,
    /// Why generation ended.
    pub stop_reason: StopReason,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub service_tier: Option<String>,
    /// Raw provider payload.
    ///
    /// Never serialized. It is always `None` after deserialization.
    #[serde(skip)]
    pub raw: Option<Value>,
}
impl ChatResponse {
#[must_use]
pub const fn new(message: Message) -> Self {
Self {
message,
stop_reason: StopReason::Stop,
usage: None,
model: None,
id: None,
service_tier: None,
raw: None,
}
}
#[must_use]
pub fn from_text(content: impl Into<String>) -> Self {
Self::new(Message::assistant(content))
}
#[must_use]
pub const fn with_stop_reason(mut self, reason: StopReason) -> Self {
self.stop_reason = reason;
self
}
#[must_use]
pub const fn with_usage(mut self, usage: Usage) -> Self {
self.usage = Some(usage);
self
}
#[must_use]
pub fn with_model(mut self, model: impl Into<String>) -> Self {
self.model = Some(model.into());
self
}
#[must_use]
pub fn with_id(mut self, id: impl Into<String>) -> Self {
self.id = Some(id.into());
self
}
#[must_use]
pub fn with_raw(mut self, raw: Value) -> Self {
self.raw = Some(raw);
self
}
#[must_use]
pub fn text(&self) -> Option<String> {
self.message.text()
}
pub fn parse<T: serde::de::DeserializeOwned>(&self) -> serde_json::Result<T> {
let text = self.text().unwrap_or_default();
serde_json::from_str(&text)
}
#[must_use]
pub fn has_tool_calls(&self) -> bool {
self.message.has_tool_calls()
}
#[must_use]
pub fn tool_calls(&self) -> Option<&[crate::message::ToolCall]> {
self.message.tool_calls.as_deref()
}
#[must_use]
pub const fn is_complete(&self) -> bool {
self.stop_reason.is_complete()
}
#[must_use]
pub const fn is_truncated(&self) -> bool {
self.stop_reason.is_truncated()
}
}
impl Default for ChatResponse {
fn default() -> Self {
Self::new(Message::default())
}
}
/// A backend that can answer [`ChatRequest`]s.
///
/// Implementors must provide [`chat`](Self::chat),
/// [`provider_name`](Self::provider_name) and [`default_model`](Self::default_model).
/// Every other method has a default.
#[async_trait]
pub trait ChatProvider: Send + Sync {
    /// Sends `request` and resolves to the complete response.
    async fn chat(&self, request: &ChatRequest) -> Result<ChatResponse>;

    /// Streams the response to `request` as [`StreamChunk`]s.
    ///
    /// The default implementation immediately fails with
    /// `LlmError::not_supported("streaming")`. Providers that stream should
    /// override it, together with [`supports_streaming`](Self::supports_streaming).
    async fn chat_stream(
        &self,
        request: &ChatRequest,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send>>> {
        // The fallback never looks at the request.
        let _ = request;
        Err(crate::error::LlmError::not_supported("streaming").into())
    }

    /// Static identifier for this provider.
    fn provider_name(&self) -> &'static str;

    /// Model used by [`complete`](Self::complete) and
    /// [`complete_with_system`](Self::complete_with_system).
    fn default_model(&self) -> &str;

    /// Capability flag; defaults to `false`.
    fn supports_streaming(&self) -> bool {
        false
    }

    /// Capability flag; defaults to `true`.
    fn supports_tools(&self) -> bool {
        true
    }

    /// Capability flag; defaults to `false`.
    fn supports_vision(&self) -> bool {
        false
    }

    /// Capability flag; defaults to `false`.
    fn supports_json_mode(&self) -> bool {
        false
    }

    /// Sends `prompt` as a single user message to the default model.
    ///
    /// Returns the reply text, or an empty string when the reply has none.
    async fn complete(&self, prompt: &str) -> Result<String> {
        let request = ChatRequest::new(self.default_model()).user(prompt);
        self.chat(&request)
            .await
            .map(|response| response.text().unwrap_or_default())
    }

    /// Sends a system message followed by `prompt` as a user message to the
    /// default model.
    ///
    /// Returns the reply text, or an empty string when the reply has none.
    async fn complete_with_system(&self, system: &str, prompt: &str) -> Result<String> {
        let request = ChatRequest::new(self.default_model())
            .system(system)
            .user(prompt);
        self.chat(&request)
            .await
            .map(|response| response.text().unwrap_or_default())
    }
}
/// Reference-counted, shareable handle to any [`ChatProvider`].
///
/// The trait requires `Send + Sync`, so the handle can cross threads.
pub type SharedChatProvider = ::std::sync::Arc<dyn ChatProvider>;