use serde::{Deserialize, Serialize};
use serde_json::Value;
/// The tool-name string `runtime.request_choice`.
///
/// NOTE(review): presumably the reserved tool name through which a model asks
/// the runtime for a [`ChoiceRequest`]; this file only defines the constant —
/// confirm against the call sites.
pub const CHOICE_REQUEST_TOOL_NAME: &str = "runtime.request_choice";
/// An ordered list of conversation [`Message`]s.
///
/// The derived `Default` produces a transcript with no messages.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct Transcript {
/// The messages held by this transcript; `add_message` appends to the end.
pub messages: Vec<Message>,
}
impl Transcript {
pub fn new() -> Self {
Self { messages: vec![] }
}
pub fn with_messages(messages: Vec<Message>) -> Self {
Self { messages }
}
pub fn add_message(&mut self, message: Message) {
self.messages.push(message);
}
pub fn is_empty(&self) -> bool {
self.messages.is_empty()
}
}
/// A single transcript entry.
///
/// Serialized as an internally tagged enum: the variant is stored in a `role`
/// field whose value is the snake_case variant name (`user`, `assistant`,
/// `assistant_tool_call`, `tool`), alongside the variant's own fields.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "snake_case")]
pub enum Message {
/// Text content with role `user`.
User { content: String },
/// Text content with role `assistant`.
Assistant { content: String },
/// A tool invocation with role `assistant_tool_call`.
///
/// Carries the same three fields as [`ToolCall`].
AssistantToolCall {
call_id: String,
tool_name: String,
arguments: Value,
},
/// A tool result with role `tool`.
///
/// NOTE(review): `call_id` presumably matches the `call_id` of the
/// `AssistantToolCall` it answers — confirm with the code that builds transcripts.
Tool {
call_id: String,
tool_name: String,
result: Value,
},
}
impl Message {
pub fn user<S: Into<String>>(content: S) -> Self {
Self::User {
content: content.into(),
}
}
pub fn assistant<S: Into<String>>(content: S) -> Self {
Self::Assistant {
content: content.into(),
}
}
pub fn tool<S1: Into<String>, S2: Into<String>>(
call_id: S1,
tool_name: S2,
result: Value,
) -> Self {
Self::Tool {
call_id: call_id.into(),
tool_name: tool_name.into(),
result,
}
}
}
/// Selection mode attached to a [`ChoiceRequest`].
///
/// Serialized as the snake_case variant name (`single` / `multiple`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ChoiceSelectionMode {
/// NOTE(review): presumably exactly one item may be selected — confirm with the consumer.
Single,
/// NOTE(review): presumably several items may be selected — confirm with the consumer.
Multiple,
}
/// One entry in the `items` list of a [`ChoiceRequest`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ChoiceItem {
/// Identifier of the item.
pub id: String,
/// Label text of the item.
pub label: String,
/// Optional description; left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
}
/// A prompt together with a list of [`ChoiceItem`]s and a selection mode.
///
/// All three fields are required when deserializing; none has a serde default.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ChoiceRequest {
/// Prompt text of the request.
pub prompt: String,
/// The [`ChoiceSelectionMode`] of the request.
pub selection_mode: ChoiceSelectionMode,
/// The items of the request, in the order given.
pub items: Vec<ChoiceItem>,
}
impl ChoiceRequest {
    /// Deserializes a `ChoiceRequest` from an already-parsed JSON [`Value`].
    ///
    /// # Errors
    ///
    /// Returns the `serde_json::Error` reported by deserialization when
    /// `value` does not have the shape of a `ChoiceRequest`.
    pub fn from_value(value: Value) -> Result<Self, serde_json::Error> {
        serde_json::from_value::<Self>(value)
    }
}
/// Description of a tool: its name, a textual description, and an input schema.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolDefinition {
/// Name of the tool.
pub name: String,
/// Free-text description of the tool.
pub description: String,
/// Arbitrary JSON value describing the tool input.
///
/// NOTE(review): presumably a JSON Schema document; this type does not
/// validate it — confirm with the provider adapters.
pub input_schema: Value,
}
impl ToolDefinition {
pub fn new<S1: Into<String>, S2: Into<String>>(
name: S1,
description: S2,
input_schema: Value,
) -> Self {
Self {
name: name.into(),
description: description.into(),
input_schema,
}
}
}
/// Tool-usage policy carried by an [`InferenceRequest`].
///
/// Variant names are serialized in snake_case. The derived `Default` is
/// [`ToolPolicy::Auto`].
///
/// NOTE(review): the per-variant descriptions below are inferred from the
/// variant names; the enforcing code is not in this file — confirm there.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ToolPolicy {
/// Presumably: tools must not be called.
None,
/// The default policy. Presumably: the model decides whether to call a tool.
#[default]
Auto,
/// Presumably: some tool call is required.
Required,
/// Carries a string; presumably the name of the one tool that must be called.
Specific(String),
}
/// Optional generation settings attached to an [`InferenceRequest`].
///
/// Every field is optional. A field that is `None` is left out of the
/// serialized form, and the derived `Default` sets every field to `None`.
///
/// NOTE(review): value ranges and exact meaning of each setting are defined by
/// whatever consumes this config, not by this file.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct GenerationConfig {
/// Optional temperature setting.
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// Optional maximum token count.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
/// Optional `top_p` setting.
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
/// Optional list of stop strings.
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
}
impl GenerationConfig {
    /// Creates a configuration with every setting unset (`None`).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the configuration with `temperature` set to `temp`.
    #[must_use]
    pub fn with_temperature(mut self, temp: f32) -> Self {
        self.temperature = Some(temp);
        self
    }

    /// Returns the configuration with `max_tokens` set to `max`.
    #[must_use]
    pub fn with_max_tokens(mut self, max: u32) -> Self {
        self.max_tokens = Some(max);
        self
    }

    /// Returns the configuration with `top_p` set to `top_p`.
    #[must_use]
    pub fn with_top_p(mut self, top_p: f32) -> Self {
        self.top_p = Some(top_p);
        self
    }

    /// Returns the configuration with `stop` set to the given list.
    ///
    /// Completes the builder: without it, `stop` is the only field that can
    /// only be set by assigning to the struct field directly.
    #[must_use]
    pub fn with_stop(mut self, stop: Vec<String>) -> Self {
        self.stop = Some(stop);
        self
    }
}
/// A tool invocation: a call id, the tool name, and JSON arguments.
///
/// Has the same three fields as [`Message::AssistantToolCall`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolCall {
/// Identifier of this call.
pub call_id: String,
/// Name of the tool being called.
pub tool_name: String,
/// Arguments of the call as an arbitrary JSON value.
pub arguments: Value,
}
impl ToolCall {
pub fn new<S1: Into<String>, S2: Into<String>>(
call_id: S1,
tool_name: S2,
arguments: Value,
) -> Self {
Self {
call_id: call_id.into(),
tool_name: tool_name.into(),
arguments,
}
}
}
/// Optional token counters.
///
/// Every counter is an `Option<u64>`. A counter that is `None` is left out of
/// the serialized form, and the derived `Default` sets every counter to `None`.
///
/// NOTE(review): what each counter measures is inferred from its field name;
/// the mapping from provider responses lives outside this file — confirm
/// there, in particular how `cached_input_tokens` relates to
/// `cache_read_input_tokens` / `cache_creation_input_tokens`.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct TokenUsage {
#[serde(skip_serializing_if = "Option::is_none")]
pub input_tokens: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub output_tokens: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub total_tokens: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub cached_input_tokens: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub cache_creation_input_tokens: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub cache_read_input_tokens: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_output_tokens: Option<u64>,
}
impl TokenUsage {
    /// Creates a usage record in which every counter is `None`.
    pub fn new() -> Self {
        Default::default()
    }
}
/// An event produced by a provider.
///
/// Serialized as an internally tagged enum: the variant is stored in a `type`
/// field whose value is the snake_case variant name (`status`, `output`,
/// `tool_call`, `choice_request`, `usage`, `complete`, `error`).
///
/// NOTE(review): ordering guarantees between events (for example whether
/// `Complete` is always last) are not established by this file.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ProviderEvent {
/// Carries a status message string.
Status { message: String },
/// Carries output text.
Output { content: String },
/// Carries a [`ToolCall`].
ToolCall { call: ToolCall },
/// Carries a [`ChoiceRequest`].
ChoiceRequest { request: ChoiceRequest },
/// Carries a [`TokenUsage`] record.
Usage { usage: TokenUsage },
/// Unit variant with no payload.
Complete,
/// Carries a `crate::ProviderError`.
Error { source: crate::ProviderError },
}
/// A free-form record: a `kind` string paired with an arbitrary JSON payload.
///
/// NOTE(review): the set of valid `kind` values and the payload shape for each
/// are not defined in this file.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RuntimeRecord {
/// String naming the kind of record.
pub kind: String,
/// Arbitrary JSON payload of the record.
pub payload: Value,
}
impl RuntimeRecord {
pub fn new<S: Into<String>>(kind: S, payload: Value) -> Self {
Self {
kind: kind.into(),
payload,
}
}
}
/// A [`Transcript`] plus a list of [`RuntimeRecord`]s.
///
/// The derived `Default` yields an empty transcript and no records.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct InferenceContext {
/// The conversation transcript; required when deserializing.
pub transcript: Transcript,
/// Runtime records; defaults to empty when absent from the input and is
/// left out of the serialized form when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub runtime_records: Vec<RuntimeRecord>,
}
impl InferenceContext {
pub fn new() -> Self {
Self::default()
}
pub fn from_transcript(transcript: Transcript) -> Self {
Self {
transcript,
runtime_records: vec![],
}
}
pub fn add_record(&mut self, record: RuntimeRecord) {
self.runtime_records.push(record);
}
}
/// A complete inference request: model identifier, optional instructions,
/// context, tools, tool policy, and generation settings.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct InferenceRequest {
/// Model identifier; `validate_model` rejects values that are empty after trimming.
pub model: String,
/// Optional instructions text; left out of the serialized form when `None`.
///
/// NOTE(review): presumably a system prompt — confirm with the provider adapters.
#[serde(skip_serializing_if = "Option::is_none")]
pub instructions: Option<String>,
/// Transcript and runtime records for this request; required when deserializing.
pub context: InferenceContext,
/// Tool definitions; defaults to empty when absent from the input and is
/// left out of the serialized form when empty.
#[serde(skip_serializing_if = "Vec::is_empty", default)]
pub tools: Vec<ToolDefinition>,
/// Tool policy; defaults to `ToolPolicy::default()` when absent from the input.
#[serde(default)]
pub tool_policy: ToolPolicy,
/// Generation settings; defaults to `GenerationConfig::default()` when
/// absent from the input.
#[serde(default)]
pub generation: GenerationConfig,
}
impl InferenceRequest {
    /// Creates a request for `model` over `transcript`.
    ///
    /// The request starts with no instructions, no tools, no runtime records,
    /// and the default tool policy and generation settings. The model
    /// identifier is not validated here; call [`Self::validate_model`] for that.
    #[must_use]
    pub fn new<S: Into<String>>(model: S, transcript: Transcript) -> Self {
        Self {
            model: model.into(),
            instructions: None,
            context: InferenceContext::from_transcript(transcript),
            tools: Vec::new(),
            tool_policy: ToolPolicy::default(),
            generation: GenerationConfig::default(),
        }
    }

    /// Checks that `model` is a usable identifier.
    ///
    /// # Errors
    ///
    /// Returns the error built by `ProviderError::invalid_request` when
    /// `model` is empty or consists only of whitespace.
    pub fn validate_model(&self) -> crate::ProviderResult<()> {
        if self.model.trim().is_empty() {
            return Err(crate::ProviderError::invalid_request(
                "InferenceRequest.model must be a non-empty model identifier",
            ));
        }
        Ok(())
    }

    /// Returns the request with `instructions` set, replacing any previous value.
    #[must_use]
    pub fn with_instructions<S: Into<String>>(mut self, instructions: S) -> Self {
        self.instructions = Some(instructions.into());
        self
    }

    /// Returns the request with its tool list replaced by `tools`.
    #[must_use]
    pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
        self.tools = tools;
        self
    }

    /// Returns the request with its tool policy replaced by `policy`.
    #[must_use]
    pub fn with_tool_policy(mut self, policy: ToolPolicy) -> Self {
        self.tool_policy = policy;
        self
    }

    /// Returns the request with its generation settings replaced by `generation`.
    #[must_use]
    pub fn with_generation(mut self, generation: GenerationConfig) -> Self {
        self.generation = generation;
        self
    }
}