use serde::ser::SerializeStruct;
use serde::{Deserialize, Serialize, Serializer};
use serde_json::Value;
/// A tool offered to the model, as carried in [`ChatRequest::tools`].
///
/// Serializes as an object with a `type` key (taken from `kind`) and a
/// `function` key. [`ToolDefinition::function`] builds one with `kind` set to
/// `"function"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
/// Tool type tag. It is renamed to `type` when (de)serialized because `type`
/// is a Rust keyword and cannot be used as a field name.
#[serde(rename = "type")]
pub kind: String,
/// The function this tool exposes.
pub function: FunctionDefinition,
}
/// Name, optional description, and parameter description of a function tool.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDefinition {
/// Function name.
pub name: String,
/// Human-readable description; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Arbitrary JSON value passed through verbatim. Presumably a JSON Schema
/// object describing the arguments; confirm against the provider's API.
pub parameters: Value,
}
impl ToolDefinition {
    /// Builds a tool definition whose `kind` is `"function"`.
    ///
    /// `description` is always stored as `Some(..)`, and `parameters` is stored
    /// unchanged.
    pub fn function(
        name: impl Into<String>,
        description: impl Into<String>,
        parameters: Value,
    ) -> Self {
        let kind = String::from("function");
        let function = FunctionDefinition {
            name: name.into(),
            description: Some(description.into()),
            parameters,
        };
        Self { kind, function }
    }
}
/// Tool-selection setting, sent as `tool_choice` in a [`ChatRequest`].
///
/// It has a hand-written `Serialize` impl. `None`, `Auto`, and `Required`
/// serialize as the strings `"none"`, `"auto"`, and `"required"`.
/// `Function(name)` serializes as
/// `{"type": "function", "function": {"name": name}}`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolChoice {
/// Serialized as `"none"`.
None,
/// Serialized as `"auto"`.
Auto,
/// Serialized as `"required"`.
Required,
/// A specific function, identified by name.
Function(String),
}
impl Serialize for ToolChoice {
    /// Emits a bare string (`"none"`, `"auto"`, or `"required"`) for the mode
    /// variants.
    ///
    /// For `Function(name)` it emits a two-field struct:
    /// `{"type": "function", "function": {"name": name}}`.
    fn serialize<S: Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> {
        // Inner payload for the `function` field; borrows the name, no copy.
        #[derive(Serialize)]
        struct Named<'a> {
            name: &'a str,
        }

        let mode = match self {
            ToolChoice::None => "none",
            ToolChoice::Auto => "auto",
            ToolChoice::Required => "required",
            ToolChoice::Function(name) => {
                // The struct form consumes the serializer, so return directly.
                let mut obj = ser.serialize_struct("ToolChoice", 2)?;
                obj.serialize_field("type", "function")?;
                obj.serialize_field("function", &Named { name })?;
                return obj.end();
            }
        };
        ser.serialize_str(mode)
    }
}
/// A tool call: an id, a type tag, and the function invoked.
///
/// It is carried in [`ChatMessage::tool_calls`] (see
/// [`ChatMessage::assistant_tool_calls`]) and in
/// [`OaiEvent::ToolCallsComplete`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCall {
/// Call identifier. Presumably echoed back as `tool_call_id` on the matching
/// tool-result message; confirm with callers.
pub id: String,
/// Call type tag, (de)serialized under the key `type`.
#[serde(rename = "type")]
pub kind: String,
/// The function name and its raw arguments.
pub function: FunctionCall,
}
/// Function name and arguments of a [`ToolCall`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionCall {
/// Name of the function to invoke.
pub name: String,
/// Arguments kept as an unparsed string. Presumably JSON-encoded text;
/// confirm before parsing.
pub arguments: String,
}
/// One message in a chat conversation.
///
/// `content` has no `skip_serializing_if`, so it is emitted even when `None`.
/// `tool_calls`, `tool_call_id`, and `name` are omitted when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
/// Message role. The constructors use `"user"`, `"system"`, `"assistant"`,
/// and `"tool"`.
pub role: String,
/// Text content. It is `None` for assistant messages that only carry tool
/// calls.
pub content: Option<String>,
/// Tool calls attached to an assistant message.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
/// On a tool-result message, the id of the call being answered.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
/// On a tool-result message, the name of the function that produced it.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
}
impl ChatMessage {
    /// Shared shape of every plain-text message: the given `role` and
    /// `content`, with all tool-related fields left `None`.
    fn text_with_role(role: &str, content: String) -> Self {
        Self {
            role: role.to_owned(),
            content: Some(content),
            tool_calls: None,
            tool_call_id: None,
            name: None,
        }
    }

    /// Creates a `"user"` message containing `content`.
    pub fn user(content: impl Into<String>) -> Self {
        Self::text_with_role("user", content.into())
    }

    /// Creates a `"system"` message containing `content`.
    pub fn system(content: impl Into<String>) -> Self {
        Self::text_with_role("system", content.into())
    }

    /// Creates an `"assistant"` message containing `content`.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::text_with_role("assistant", content.into())
    }

    /// Creates an `"assistant"` message that carries `tool_calls` and has no
    /// text content.
    pub fn assistant_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
        Self {
            content: None,
            tool_calls: Some(tool_calls),
            ..Self::text_with_role("assistant", String::new())
        }
    }

    /// Creates a `"tool"` message that answers the call `tool_call_id`.
    ///
    /// The function `name` and the `content` result are stored alongside it.
    pub fn tool_result(
        tool_call_id: impl Into<String>,
        name: impl Into<String>,
        content: impl Into<String>,
    ) -> Self {
        // Convert `content` first, then `tool_call_id`, then `name`.
        let base = Self::text_with_role("tool", content.into());
        Self {
            tool_call_id: Some(tool_call_id.into()),
            name: Some(name.into()),
            ..base
        }
    }

    /// Borrows the text content, if any.
    pub fn content(&self) -> Option<&str> {
        self.content.as_ref().map(String::as_str)
    }
}
/// Optional per-call settings. The derived `Default` leaves every field `None`.
///
/// The fields share names and types with the matching optional fields of
/// [`ChatRequest`]. Presumably they are copied into it when a request is
/// built; confirm at the call site.
#[derive(Debug, Clone, Default)]
pub struct ChatOptions {
/// Upper bound on generated tokens, if set.
pub max_tokens: Option<u32>,
/// Sampling temperature, if set.
pub temperature: Option<f32>,
/// Tools offered to the model, if any.
pub tools: Option<Vec<ToolDefinition>>,
/// Tool-selection setting, if any.
pub tool_choice: Option<ToolChoice>,
}
/// Value of [`ChatRequest::stream_options`].
#[derive(Debug, Clone, Serialize)]
pub struct StreamOptions {
/// Presumably asks the server to append token usage to the stream (see
/// [`OaiEvent::Usage`]). Confirm against the provider's API.
pub include_usage: bool,
}
/// Serializable chat request payload.
///
/// `model`, `messages`, and `stream` are always emitted. Every `Option` field
/// is omitted when `None`.
#[derive(Debug, Clone, Serialize)]
pub struct ChatRequest {
/// Model identifier.
pub model: String,
/// Conversation history, in order.
pub messages: Vec<ChatMessage>,
/// Whether a streamed response is requested.
pub stream: bool,
/// Extra streaming settings; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub stream_options: Option<StreamOptions>,
/// Upper bound on generated tokens; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
/// Sampling temperature; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// Tools offered to the model; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<ToolDefinition>>,
/// Tool-selection setting; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
}
/// Why generation stopped.
///
/// The variant names presumably mirror the API's `finish_reason` values
/// `stop`, `length`, `tool_calls`, and `content_filter`. Confirm the mapping
/// where this type is parsed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FinishReason {
/// Presumably a natural stop.
Stop,
/// Presumably hit a token limit.
Length,
/// Presumably stopped in order to emit tool calls.
ToolCalls,
/// Presumably stopped by content filtering.
ContentFilter,
}
/// Events presumably emitted while consuming a streamed chat response.
///
/// Marked `#[non_exhaustive]`, so `match` expressions outside this crate need
/// a wildcard arm.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum OaiEvent {
/// A message role announced at the start of a response.
RoleStart(String),
/// An incremental chunk of text content.
TextDelta(String),
/// Start of the tool call at position `index`, with its id and function name.
ToolCallStart {
index: u32,
id: String,
name: String,
},
/// A fragment of the arguments string for the tool call at `index`.
/// Fragments presumably concatenate in arrival order into
/// `FunctionCall::arguments`.
ToolCallArgumentsDelta {
index: u32,
id: String,
delta: String,
},
/// The assembled tool calls. `truncated` presumably flags that the stream
/// ended before the calls were complete; confirm at the producer.
ToolCallsComplete {
calls: Vec<ToolCall>,
truncated: bool,
},
/// Token usage counts. `cached_tokens` presumably counts prompt tokens served
/// from a provider-side cache.
Usage {
prompt_tokens: u32,
completion_tokens: u32,
cached_tokens: u32,
},
/// A warning message, presumably non-fatal.
Warning(String),
/// End of the stream.
Done,
}
/// Connection settings for a chat provider.
///
/// It deliberately does not derive `Debug`. A manual `Debug` impl replaces
/// `api_key` with `"[REDACTED]"`.
#[derive(Clone)]
pub struct ProviderConfig {
/// Base URL of the provider's API.
pub base_url: String,
/// Secret API key; never printed by `Debug`.
pub api_key: String,
/// Model identifier to request.
pub model: String,
/// Provider identifier (a free-form string; its accepted values are not
/// visible here).
pub provider: String,
}
/// Debug formatting that shows every field but masks the API key.
///
/// `api_key` is rendered as `"[REDACTED]"`, so `{:?}` output (logs, panic
/// messages) never contains the secret. All other fields are printed as-is.
impl std::fmt::Debug for ProviderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ProviderConfig")
            .field("base_url", &self.base_url)
            .field("api_key", &"[REDACTED]")
            .field("model", &self.model)
            // Previously omitted, which hid which backend a config targets.
            .field("provider", &self.provider)
            .finish()
    }
}