use std::collections::HashMap;
use std::fmt;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use chrono::{DateTime, Utc};
use secrecy::SecretString;
use serde::{Deserialize, Serialize};
use crate::chat::Message;
use crate::tools::Tool;
/// Strategy for how the model may use the tools attached to a request.
///
/// NOTE(review): only the unit variants carry `#[serde(rename)]`; `Function`
/// serializes with serde's default external tagging (`{"Function": {...}}`),
/// while the provider wire shape for a forced call is produced separately by
/// the `From<ToolChoice> for serde_json::Value` impl below — confirm which
/// representation round-trips are expected to use.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum ToolChoice {
    /// The model decides on its own whether to call a tool.
    #[serde(rename = "auto")]
    Auto,
    /// The model must not call any tool.
    #[serde(rename = "none")]
    None,
    /// The model must call at least one tool.
    #[serde(rename = "required")]
    Required,
    /// The model is forced to call the function with the given name.
    Function {
        name: String,
    },
}
impl fmt::Display for ToolChoice {
    /// Renders the choice as its lowercase keyword, or as the bare function
    /// name for a forced function call.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            ToolChoice::Auto => "auto",
            ToolChoice::None => "none",
            ToolChoice::Required => "required",
            ToolChoice::Function { name } => name.as_str(),
        };
        f.write_str(label)
    }
}
impl From<ToolChoice> for serde_json::Value {
fn from(tool_choice: ToolChoice) -> Self {
match tool_choice {
ToolChoice::Auto => serde_json::Value::String("auto".to_string()),
ToolChoice::None => serde_json::Value::String("none".to_string()),
ToolChoice::Required => serde_json::Value::String("required".to_string()),
ToolChoice::Function { name } => {
serde_json::json!({
"type": "function",
"function": {
"name": name
}
})
}
}
}
}
/// Why the model stopped generating, as reported by the provider.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Copy)]
#[non_exhaustive]
pub enum FinishReason {
    /// Natural end of generation (or a stop sequence was hit).
    #[serde(rename = "stop")]
    Stop,
    /// The `max_tokens` limit was reached before completion.
    #[serde(rename = "length")]
    Length,
    /// Generation stopped because the model emitted tool calls.
    #[serde(rename = "tool_calls")]
    ToolCalls,
    /// Output was truncated by the provider's content filter.
    #[serde(rename = "content_filter")]
    ContentFilter,
    /// The provider reported a model-side error.
    #[serde(rename = "model_error")]
    ModelError,
}
impl fmt::Display for FinishReason {
    /// Writes the same snake_case token the serde renames use, so `Display`
    /// and serialization stay in sync.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            FinishReason::Stop => "stop",
            FinishReason::Length => "length",
            FinishReason::ToolCalls => "tool_calls",
            FinishReason::ContentFilter => "content_filter",
            FinishReason::ModelError => "model_error",
        })
    }
}
impl FromStr for FinishReason {
    type Err = anyhow::Error;

    /// Parses the snake_case token emitted by `Display`; anything else is
    /// rejected with an error naming the offending input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let reason = match s {
            "stop" => FinishReason::Stop,
            "length" => FinishReason::Length,
            "tool_calls" => FinishReason::ToolCalls,
            "content_filter" => FinishReason::ContentFilter,
            "model_error" => FinishReason::ModelError,
            other => anyhow::bail!("Unknown finish reason: {}", other),
        };
        Ok(reason)
    }
}
/// Retry policy for provider calls (exponential backoff parameters).
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of retry attempts after the initial call.
    pub max_retries: usize,
    /// Delay before the first retry.
    pub initial_delay: Duration,
    /// Upper bound on the backoff delay.
    pub max_delay: Duration,
    /// Factor applied to the delay after each failed attempt.
    pub backoff_multiplier: f64,
    /// Whether to randomize delays to avoid thundering-herd retries.
    pub jitter: bool,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_retries: 3,
initial_delay: Duration::from_millis(1000),
max_delay: Duration::from_secs(30),
backoff_multiplier: 2.0,
jitter: true,
}
}
}
/// Token accounting for a chat completion.
///
/// The `alias` attributes let deserialization also accept the
/// `input_tokens`/`output_tokens` field names — presumably for providers
/// that use that naming (Anthropic-style); confirm against the adapters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    #[serde(alias = "input_tokens")]
    pub prompt_tokens: u32,
    #[serde(alias = "output_tokens")]
    pub completion_tokens: u32,
    /// NOTE(review): not recomputed here; assumed to be supplied consistent
    /// with prompt + completion by the provider.
    pub total_tokens: u32,
    /// Monetary cost of the call, when the provider or caller computes one.
    pub cost: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_tokens_details: Option<InputTokensDetails>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_tokens_details: Option<OutputTokensDetails>,
}
/// Breakdown of prompt-side token usage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InputTokensDetails {
    /// Prompt tokens served from the provider's prompt cache.
    pub cached_tokens: u32,
}
/// Breakdown of completion-side token usage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputTokensDetails {
    /// Tokens spent on hidden reasoning/thinking, where the model reports it.
    pub reasoning_tokens: u32,
}
/// A provider-agnostic chat completion request.
///
/// Construct with [`ChatRequest::new`] plus the `with_*` builders, or derive
/// one from a [`Config`] via the `From` conversions below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatRequest {
    /// Conversation history. `Arc<[Message]>` makes cloning the request a
    /// refcount bump rather than a copy of every message.
    pub messages: Arc<[Message]>,
    /// Model identifier; `None` lets the provider/config decide.
    pub model: Option<String>,
    /// Sampling temperature; validated to [0.0, 2.0] by `validate`.
    pub temperature: Option<f32>,
    /// Hard cap on generated tokens.
    pub max_tokens: Option<u32>,
    /// Nucleus sampling parameter; validated to [0.0, 1.0] by `validate`.
    pub top_p: Option<f32>,
    /// Validated to [-2.0, 2.0] by `validate`.
    pub frequency_penalty: Option<f32>,
    /// Validated to [-2.0, 2.0] by `validate`.
    pub presence_penalty: Option<f32>,
    /// Stop sequences that end generation.
    pub stop: Option<Vec<String>>,
    /// Tools the model may call.
    pub tools: Option<Vec<Tool>>,
    /// How the model may choose among `tools`.
    pub tool_choice: Option<ToolChoice>,
    /// Whether the response should be streamed as chunks.
    pub stream: bool,
    /// End-user identifier passed through to the provider.
    pub user: Option<String>,
    /// Opt-in flag for models with an explicit thinking/reasoning mode.
    pub enable_thinking: Option<bool>,
    /// Free-form, provider-specific extras.
    pub metadata: HashMap<String, serde_json::Value>,
}
impl fmt::Display for ChatRequest {
    /// Formats the request as its JSON serialization; a serialization
    /// failure is reported inline instead of aborting formatting.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let Ok(json) = serde_json::to_string(self) {
            write!(f, "{}", json)
        } else {
            write!(f, "Error serializing ChatRequest to JSON")
        }
    }
}
impl ChatRequest {
pub fn new(messages: impl Into<Arc<[Message]>>) -> Self {
Self {
messages: messages.into(),
model: None,
temperature: None,
max_tokens: None,
top_p: None,
frequency_penalty: None,
presence_penalty: None,
stop: None,
tools: None,
tool_choice: None,
stream: false,
user: None,
enable_thinking: None,
metadata: HashMap::new(),
}
}
}
impl From<(&Config, Vec<Message>)> for ChatRequest {
fn from((config, messages): (&Config, Vec<Message>)) -> Self {
Self {
messages: messages.into(),
model: Some(config.model.clone()),
temperature: config.temperature,
max_tokens: config.max_tokens,
top_p: config.top_p,
frequency_penalty: config.frequency_penalty,
presence_penalty: config.presence_penalty,
stop: config.stop_sequences.clone(),
tools: None,
tool_choice: None,
stream: false,
user: None,
enable_thinking: None,
metadata: HashMap::new(),
}
}
}
impl From<(&Config, Arc<[Message]>)> for ChatRequest {
fn from((config, messages): (&Config, Arc<[Message]>)) -> Self {
Self {
messages,
model: Some(config.model.clone()),
temperature: config.temperature,
max_tokens: config.max_tokens,
top_p: config.top_p,
frequency_penalty: config.frequency_penalty,
presence_penalty: config.presence_penalty,
stop: config.stop_sequences.clone(),
tools: None,
tool_choice: None,
stream: false,
user: None,
enable_thinking: None,
metadata: HashMap::new(),
}
}
}
impl ChatRequest {
    /// Sets the model identifier.
    pub fn with_model(self, model: impl Into<String>) -> Self {
        Self { model: Some(model.into()), ..self }
    }

    /// Sets the sampling temperature (validated later by [`Self::validate`]).
    pub fn with_temperature(self, temperature: f32) -> Self {
        Self { temperature: Some(temperature), ..self }
    }

    /// Caps the number of generated tokens.
    pub fn with_max_tokens(self, max_tokens: u32) -> Self {
        Self { max_tokens: Some(max_tokens), ..self }
    }

    /// Sets the nucleus-sampling parameter.
    pub fn with_top_p(self, top_p: f32) -> Self {
        Self { top_p: Some(top_p), ..self }
    }

    /// Sets the frequency penalty.
    pub fn with_frequency_penalty(self, frequency_penalty: f32) -> Self {
        Self { frequency_penalty: Some(frequency_penalty), ..self }
    }

    /// Sets the presence penalty.
    pub fn with_presence_penalty(self, presence_penalty: f32) -> Self {
        Self { presence_penalty: Some(presence_penalty), ..self }
    }

    /// Sets the stop sequences; accepts anything convertible to `String`.
    pub fn with_stop_sequences(
        self,
        stop_sequences: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        let collected: Vec<String> = stop_sequences.into_iter().map(Into::into).collect();
        Self { stop: Some(collected), ..self }
    }

    /// Attaches the tools the model may call.
    pub fn with_tools(self, tools: Vec<Tool>) -> Self {
        Self { tools: Some(tools), ..self }
    }

    /// Constrains how the model may choose among the attached tools.
    pub fn with_tool_choice(self, tool_choice: ToolChoice) -> Self {
        Self { tool_choice: Some(tool_choice), ..self }
    }

    /// Enables or disables streaming responses.
    pub fn with_streaming(self, stream: bool) -> Self {
        Self { stream, ..self }
    }

    /// Replaces the provider-specific metadata map.
    pub fn with_metadata(self, metadata: HashMap<String, serde_json::Value>) -> Self {
        Self { metadata, ..self }
    }

    /// Opts in or out of the model's explicit thinking/reasoning mode.
    pub fn with_thinking(self, enable_thinking: bool) -> Self {
        Self { enable_thinking: Some(enable_thinking), ..self }
    }

    /// Fails if the request carries no messages at all.
    pub fn validate_has_messages(&self) -> anyhow::Result<()> {
        anyhow::ensure!(
            !self.messages.is_empty(),
            "Chat request must have at least one message"
        );
        Ok(())
    }

    /// Checks that messages are present and that every sampling override
    /// falls within its accepted range.
    ///
    /// # Errors
    /// Returns an error naming the first out-of-range parameter found.
    pub fn validate(&self) -> anyhow::Result<()> {
        self.validate_has_messages()?;
        if let Some(temp) = self.temperature {
            if !(0.0..=2.0).contains(&temp) {
                anyhow::bail!("Temperature must be between 0.0 and 2.0, got {}", temp);
            }
        }
        if let Some(top_p) = self.top_p {
            if !(0.0..=1.0).contains(&top_p) {
                anyhow::bail!("top_p must be between 0.0 and 1.0, got {}", top_p);
            }
        }
        if let Some(freq_penalty) = self.frequency_penalty {
            if !(-2.0..=2.0).contains(&freq_penalty) {
                anyhow::bail!(
                    "frequency_penalty must be between -2.0 and 2.0, got {}",
                    freq_penalty
                );
            }
        }
        if let Some(pres_penalty) = self.presence_penalty {
            if !(-2.0..=2.0).contains(&pres_penalty) {
                anyhow::bail!(
                    "presence_penalty must be between -2.0 and 2.0, got {}",
                    pres_penalty
                );
            }
        }
        Ok(())
    }

    /// True when the request carries a non-empty tool list.
    pub fn has_tools(&self) -> bool {
        matches!(&self.tools, Some(tools) if !tools.is_empty())
    }

    /// True when a streamed response was requested.
    pub fn is_streaming(&self) -> bool {
        self.stream
    }
}
/// A complete (non-streamed) chat completion from a provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatResponse {
    /// The assistant message produced by the model.
    pub message: Message,
    /// The model that actually served the request.
    pub model: String,
    /// Token accounting, when the provider reports it.
    pub usage: Option<Usage>,
    /// Why generation stopped, when reported.
    pub finish_reason: Option<FinishReason>,
    /// Timestamp of the response.
    pub created_at: DateTime<Utc>,
    /// Provider-assigned response identifier, when available.
    pub response_id: Option<String>,
    /// Free-form, provider-specific extras.
    pub metadata: HashMap<String, serde_json::Value>,
}
/// One incremental chunk of a streamed chat response.
///
/// The `delta_*` fields carry only the piece added by this chunk; callers
/// accumulate them across chunks — NOTE(review): accumulation semantics live
/// in the stream consumer, not here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatChunk {
    /// The model that served the request.
    pub model: String,
    /// New content text in this chunk, if any.
    pub delta_content: Option<String>,
    /// Role, typically present only on the first chunk of a message.
    pub delta_role: Option<crate::chat::MessageRole>,
    /// Incremental tool-call fragments, if any.
    pub delta_tool_calls: Option<Vec<crate::tools::ToolCall>>,
    /// Set on the final chunk of the stream, when reported.
    pub finish_reason: Option<FinishReason>,
    /// Token accounting; providers usually attach it to the last chunk.
    pub usage: Option<Usage>,
    /// Provider-assigned response identifier, when available.
    pub response_id: Option<String>,
    /// Timestamp of the chunk.
    pub created_at: DateTime<Utc>,
    /// Free-form, provider-specific extras.
    pub metadata: HashMap<String, serde_json::Value>,
}
impl fmt::Display for ChatResponse {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match serde_json::to_string(self) {
Ok(json) => write!(f, "{}", json),
Err(_) => write!(f, "Error serializing ChatResponse to JSON"),
}
}
}
/// Provider/model configuration from which chat requests are derived.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// Provider identifier (e.g. the default is "ollama").
    pub provider: String,
    /// Model identifier copied into requests built from this config.
    pub model: String,
    /// Override for the provider's endpoint URL.
    pub base_url: Option<String>,
    /// API credential. `skip_serializing` ensures the secret is never
    /// written out; `default` makes it `None` when absent on deserialize.
    #[serde(skip_serializing, default)]
    pub api_key: Option<SecretString>,
    /// Organization/tenant identifier, for providers that use one.
    pub organization: Option<String>,
    /// Request timeout in seconds, if set.
    pub timeout_seconds: Option<u64>,
    /// Retry policy; `#[serde(skip)]` — always `RetryConfig::default()`
    /// after deserialization, never serialized.
    #[serde(skip)]
    pub retry_config: RetryConfig,
    /// Default sampling overrides copied into derived requests.
    pub temperature: Option<f32>,
    pub max_tokens: Option<u32>,
    pub top_p: Option<f32>,
    pub frequency_penalty: Option<f32>,
    pub presence_penalty: Option<f32>,
    pub stop_sequences: Option<Vec<String>>,
    /// Free-form, provider-specific extras.
    pub metadata: HashMap<String, serde_json::Value>,
}
impl Default for Config {
fn default() -> Self {
Self {
provider: "ollama".to_string(),
model: "gpt-oss:20b".to_string(),
base_url: None,
api_key: None,
organization: None,
timeout_seconds: None,
retry_config: RetryConfig::default(),
temperature: None,
max_tokens: None,
top_p: None,
frequency_penalty: None,
presence_penalty: None,
stop_sequences: None,
metadata: HashMap::new(),
}
}
}
impl Config {
pub fn new(provider: impl Into<String>, model: impl Into<String>) -> Self {
Self {
provider: provider.into(),
model: model.into(),
..Default::default()
}
}
pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
self.base_url = Some(base_url.into());
self
}
pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
self.api_key = Some(SecretString::new(api_key.into().into()));
self
}
pub fn with_organization(mut self, organization: impl Into<String>) -> Self {
self.organization = Some(organization.into());
self
}
pub fn with_timeout(mut self, timeout_seconds: u64) -> Self {
self.timeout_seconds = Some(timeout_seconds);
self
}
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = Some(temperature);
self
}
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = Some(max_tokens);
self
}
pub fn with_top_p(mut self, top_p: f32) -> Self {
self.top_p = Some(top_p);
self
}
pub fn with_frequency_penalty(mut self, frequency_penalty: f32) -> Self {
self.frequency_penalty = Some(frequency_penalty);
self
}
pub fn with_presence_penalty(mut self, presence_penalty: f32) -> Self {
self.presence_penalty = Some(presence_penalty);
self
}
pub fn with_stop_sequences(
mut self,
stop_sequences: impl IntoIterator<Item = impl Into<String>>,
) -> Self {
self.stop_sequences = Some(stop_sequences.into_iter().map(|s| s.into()).collect());
self
}
pub fn with_metadata(mut self, metadata: HashMap<String, serde_json::Value>) -> Self {
self.metadata = metadata;
self
}
pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
self.retry_config = retry_config;
self
}
}
impl From<(Config, Vec<Message>)> for ChatRequest {
    /// Consumes a config plus messages and produces a request carrying the
    /// config's model, every sampling override that is set, and its metadata.
    fn from((config, messages): (Config, Vec<Message>)) -> Self {
        let mut request = ChatRequest::new(messages).with_model(&config.model);
        if let Some(temperature) = config.temperature {
            request = request.with_temperature(temperature);
        }
        if let Some(limit) = config.max_tokens {
            request = request.with_max_tokens(limit);
        }
        if let Some(top_p) = config.top_p {
            request = request.with_top_p(top_p);
        }
        if let Some(penalty) = config.frequency_penalty {
            request = request.with_frequency_penalty(penalty);
        }
        if let Some(penalty) = config.presence_penalty {
            request = request.with_presence_penalty(penalty);
        }
        if let Some(sequences) = config.stop_sequences {
            request = request.with_stop_sequences(sequences);
        }
        request.with_metadata(config.metadata)
    }
}
impl Config {
pub fn into_chat_request(self, messages: Vec<Message>) -> ChatRequest {
(self, messages).into()
}
pub fn validate(&self) -> anyhow::Result<()> {
if let Some(temp) = self.temperature
&& !(0.0..=2.0).contains(&temp)
{
anyhow::bail!("Temperature must be between 0.0 and 2.0, got {}", temp);
}
if let Some(top_p) = self.top_p
&& !(0.0..=1.0).contains(&top_p)
{
anyhow::bail!("top_p must be between 0.0 and 1.0, got {}", top_p);
}
if let Some(freq_penalty) = self.frequency_penalty
&& !(-2.0..=2.0).contains(&freq_penalty)
{
anyhow::bail!(
"frequency_penalty must be between -2.0 and 2.0, got {}",
freq_penalty
);
}
if let Some(pres_penalty) = self.presence_penalty
&& !(-2.0..=2.0).contains(&pres_penalty)
{
anyhow::bail!(
"presence_penalty must be between -2.0 and 2.0, got {}",
pres_penalty
);
}
Ok(())
}
}
// Property-based and unit tests for the builder/validation surface.
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Config::validate accepts exactly the [0.0, 2.0] temperature range.
        #[test]
        fn temperature_validation(temp in -10.0f32..10.0f32) {
            let config = Config::new("openai", "gpt-4").with_temperature(temp);
            let is_valid = (0.0..=2.0).contains(&temp);
            assert_eq!(config.validate().is_ok(), is_valid);
        }

        // Config::validate accepts exactly the [0.0, 1.0] top_p range.
        #[test]
        fn top_p_validation(top_p in -5.0f32..5.0f32) {
            let config = Config::new("openai", "gpt-4").with_top_p(top_p);
            let is_valid = (0.0..=1.0).contains(&top_p);
            assert_eq!(config.validate().is_ok(), is_valid);
        }

        // Config::validate accepts exactly the [-2.0, 2.0] frequency-penalty range.
        #[test]
        fn frequency_penalty_validation(penalty in -10.0f32..10.0f32) {
            let config = Config::new("openai", "gpt-4").with_frequency_penalty(penalty);
            let is_valid = (-2.0..=2.0).contains(&penalty);
            assert_eq!(config.validate().is_ok(), is_valid);
        }

        // Config::validate accepts exactly the [-2.0, 2.0] presence-penalty range.
        #[test]
        fn presence_penalty_validation(penalty in -10.0f32..10.0f32) {
            let config = Config::new("openai", "gpt-4").with_presence_penalty(penalty);
            let is_valid = (-2.0..=2.0).contains(&penalty);
            assert_eq!(config.validate().is_ok(), is_valid);
        }

        // max_tokens is unconstrained by validation: any u32 passes.
        #[test]
        fn max_tokens_validation(tokens in 0u32..1000000u32) {
            let config = Config::new("openai", "gpt-4").with_max_tokens(tokens);
            assert!(config.validate().is_ok());
        }

        // Builders accept &str inputs and store them verbatim.
        #[test]
        fn config_builder_with_string_slice(
            provider in ".*",
            model in ".*",
            base_url in ".*",
        ) {
            let config = Config::new(provider.as_str(), model.as_str())
                .with_base_url(base_url.as_str());
            assert_eq!(config.provider, provider);
            assert_eq!(config.model, model);
            assert_eq!(config.base_url, Some(base_url));
        }

        // Builders also accept owned Strings.
        #[test]
        fn config_builder_with_owned_string(
            provider in ".*",
            model in ".*",
        ) {
            let config = Config::new(provider.clone(), model.clone());
            assert_eq!(config.provider, provider);
            assert_eq!(config.model, model);
        }

        // with_stop_sequences accepts Vec<String> and Vec<&str> alike and
        // produces the same stored sequences either way.
        #[test]
        fn stop_sequences_accepts_various_types(
            sequences in prop::collection::vec(".*", 0..10),
        ) {
            let config1 = Config::new("openai", "gpt-4")
                .with_stop_sequences(sequences.clone());
            assert_eq!(config1.stop_sequences, Some(sequences.clone()));
            let str_refs: Vec<&str> = sequences.iter().map(|s| s.as_str()).collect();
            let config2 = Config::new("openai", "gpt-4")
                .with_stop_sequences(str_refs);
            assert_eq!(config2.stop_sequences, Some(sequences.clone()));
            // NOTE(review): this branch repeats the Vec<&str> case above
            // rather than exercising a fixed-size array as the name hints.
            if sequences.len() <= 3 {
                let arr: Vec<&str> = sequences.iter().map(|s| s.as_str()).collect();
                let config3 = Config::new("openai", "gpt-4")
                    .with_stop_sequences(arr);
                assert_eq!(config3.stop_sequences, Some(sequences));
            }
        }

        // Chained builders neither clobber earlier values nor fail validation
        // for in-range inputs.
        #[test]
        fn builder_chain_preserves_all_values(
            provider in ".*",
            model in ".*",
            temp in 0.0f32..2.0f32,
            max_tokens in 0u32..100000u32,
        ) {
            let config = Config::new(provider.as_str(), model.as_str())
                .with_temperature(temp)
                .with_max_tokens(max_tokens);
            assert_eq!(config.provider, provider);
            assert_eq!(config.model, model);
            assert_eq!(config.temperature, Some(temp));
            assert_eq!(config.max_tokens, Some(max_tokens));
            assert!(config.validate().is_ok());
        }

        // ChatRequest::validate enforces the same temperature range as Config.
        #[test]
        fn chat_request_temperature_validation(
            temp in -10.0f32..10.0f32,
            msg_count in 1usize..10,
        ) {
            use crate::chat::{Message, MessageRole};
            use uuid::Uuid;
            let messages: Vec<Message> = (0..msg_count)
                .map(|i| Message::new(Uuid::new_v4(), MessageRole::User, format!("message {}", i)))
                .collect();
            let request = ChatRequest::new(messages).with_temperature(temp);
            let is_valid = (0.0..=2.0).contains(&temp);
            assert_eq!(request.validate().is_ok(), is_valid);
        }

        // with_model accepts both &str and owned String.
        #[test]
        fn chat_request_with_string_types(
            model in ".*",
        ) {
            use crate::chat::{Message, MessageRole};
            use uuid::Uuid;
            let msg = Message::new(Uuid::new_v4(), MessageRole::User, "test");
            let request1 = ChatRequest::new(vec![msg.clone()])
                .with_model(model.as_str());
            assert_eq!(request1.model, Some(model.clone()));
            let request2 = ChatRequest::new(vec![msg])
                .with_model(model.clone());
            assert_eq!(request2.model, Some(model));
        }

        // ChatRequest::with_stop_sequences mirrors the Config ergonomics:
        // Vec<String> and Vec<&str> produce identical stored values.
        #[test]
        fn chat_request_stop_sequences_ergonomics(
            sequences in prop::collection::vec(".*", 1..5),
        ) {
            use crate::chat::{Message, MessageRole};
            use uuid::Uuid;
            let msg = Message::new(Uuid::new_v4(), MessageRole::User, "test");
            let request1 = ChatRequest::new(vec![msg.clone()])
                .with_stop_sequences(sequences.clone());
            assert_eq!(request1.stop, Some(sequences.clone()));
            let str_refs: Vec<&str> = sequences.iter().map(|s| s.as_str()).collect();
            let request2 = ChatRequest::new(vec![msg])
                .with_stop_sequences(str_refs);
            assert_eq!(request2.stop, Some(sequences));
        }

        // A full builder chain preserves every value and validates for
        // in-range inputs.
        #[test]
        fn chat_request_builder_chain(
            model in ".*",
            temp in 0.0f32..2.0f32,
            max_tokens in 0u32..100000u32,
            top_p in 0.0f32..1.0f32,
        ) {
            use crate::chat::{Message, MessageRole};
            use uuid::Uuid;
            let msg = Message::new(Uuid::new_v4(), MessageRole::User, "test");
            let request = ChatRequest::new(vec![msg])
                .with_model(model.as_str())
                .with_temperature(temp)
                .with_max_tokens(max_tokens)
                .with_top_p(top_p);
            assert_eq!(request.model, Some(model));
            assert_eq!(request.temperature, Some(temp));
            assert_eq!(request.max_tokens, Some(max_tokens));
            assert_eq!(request.top_p, Some(top_p));
            assert!(request.validate().is_ok());
        }
    }

    // An empty message list must fail both the full and the message-only check.
    #[test]
    fn chat_request_validates_empty_messages() {
        let request = ChatRequest::new(vec![]);
        assert!(request.validate().is_err());
        assert!(request.validate_has_messages().is_err());
    }

    // has_tools is false for absent AND empty tool lists, true otherwise.
    #[test]
    fn chat_request_has_tools() {
        use crate::chat::{Message, MessageRole};
        use crate::tools::{Function, Tool};
        use uuid::Uuid;
        let msg = Message::new(Uuid::new_v4(), MessageRole::User, "test");
        let request_no_tools = ChatRequest::new(vec![msg.clone()]);
        assert!(!request_no_tools.has_tools());
        let request_empty_tools = ChatRequest::new(vec![msg.clone()]).with_tools(vec![]);
        assert!(!request_empty_tools.has_tools());
        let function = Function {
            name: "test_function".to_string(),
            description: "A test function".to_string(),
            parameters: serde_json::json!({}),
        };
        let tool = Tool::builder().function(function).build();
        let request_with_tools = ChatRequest::new(vec![msg]).with_tools(vec![tool]);
        assert!(request_with_tools.has_tools());
    }
}