use crate::common::parameters::{Name, ParameterProperty, Parameters};
use crate::common::tool::Tool;
use serde::{Deserialize, Serialize};
use super::audio::{AudioFormat, InputAudioNoiseReduction, InputAudioTranscription, Voice};
use super::vad::TurnDetection;
/// A tool definition in the flat shape used by [`SessionConfig`] and
/// [`ResponseCreateConfig`].
///
/// `name`, `description` and `parameters` live at the top level of this type;
/// the `From<Tool>` impl lifts them out of `Tool::function` when it is present.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RealtimeTool {
/// Kind of tool; serialized under the key `type` (for example `"function"`).
#[serde(rename = "type")]
pub type_name: String,
/// Name of the tool.
pub name: String,
/// Optional description; left out of the serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Optional parameter definition; left out of the serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub parameters: Option<Parameters>,
}
impl RealtimeTool {
pub fn function<T, U, V>(name: T, description: U, parameters: Vec<(V, ParameterProperty)>) -> Self
where
T: Into<String>,
U: Into<String>,
V: AsRef<str>,
{
let params: Vec<(Name, ParameterProperty)> = parameters.into_iter().map(|(k, v)| (k.as_ref().to_string(), v)).collect();
Self {
type_name: "function".to_string(),
name: name.into(),
description: Some(description.into()),
parameters: Some(Parameters::new(params, None)),
}
}
}
impl From<Tool> for RealtimeTool {
fn from(tool: Tool) -> Self {
if let Some(func) = tool.function {
Self { type_name: "function".to_string(), name: func.name, description: func.description, parameters: func.parameters }
} else {
Self { type_name: tool.type_name, name: tool.name.unwrap_or_default(), description: None, parameters: tool.parameters }
}
}
}
/// A modality that can be listed in `SessionConfig::modalities` and
/// `ResponseCreateConfig::modalities`.
///
/// Variant names are serialized in lowercase.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Modality {
/// Serialized as `"text"`.
Text,
/// Serialized as `"audio"`.
Audio,
}
/// Session-level configuration.
///
/// Every field is optional and is left out of the serialized output when it is
/// `None`. `Default` (and [`SessionConfig::new`]) yields a value with all
/// fields unset; the `with_*` builder methods fill in individual fields.
// NOTE(review): field names look like they mirror a remote realtime API's
// session object — confirm against that API's reference before renaming any.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SessionConfig {
/// Modalities for the session.
#[serde(skip_serializing_if = "Option::is_none")]
pub modalities: Option<Vec<Modality>>,
/// Instruction text for the session.
#[serde(skip_serializing_if = "Option::is_none")]
pub instructions: Option<String>,
/// Voice selection.
#[serde(skip_serializing_if = "Option::is_none")]
pub voice: Option<Voice>,
/// Format of input audio.
#[serde(skip_serializing_if = "Option::is_none")]
pub input_audio_format: Option<AudioFormat>,
/// Format of output audio.
#[serde(skip_serializing_if = "Option::is_none")]
pub output_audio_format: Option<AudioFormat>,
/// Input-audio transcription configuration.
#[serde(skip_serializing_if = "Option::is_none")]
pub input_audio_transcription: Option<InputAudioTranscription>,
/// Input-audio noise-reduction configuration.
#[serde(skip_serializing_if = "Option::is_none")]
pub input_audio_noise_reduction: Option<InputAudioNoiseReduction>,
/// Turn-detection configuration.
#[serde(skip_serializing_if = "Option::is_none")]
pub turn_detection: Option<TurnDetection>,
/// Tool definitions, in [`RealtimeTool`] form.
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<RealtimeTool>>,
/// Tool-choice setting.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
/// Temperature value.
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// Output-token limit; see [`MaxTokens`] for the wire representation.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_response_output_tokens: Option<MaxTokens>,
}
impl SessionConfig {
pub fn new() -> Self {
Self::default()
}
pub fn with_modalities(mut self, modalities: Vec<Modality>) -> Self {
self.modalities = Some(modalities);
self
}
pub fn with_instructions(mut self, instructions: impl Into<String>) -> Self {
self.instructions = Some(instructions.into());
self
}
pub fn with_voice(mut self, voice: Voice) -> Self {
self.voice = Some(voice);
self
}
pub fn with_input_audio_format(mut self, format: AudioFormat) -> Self {
self.input_audio_format = Some(format);
self
}
pub fn with_output_audio_format(mut self, format: AudioFormat) -> Self {
self.output_audio_format = Some(format);
self
}
pub fn with_transcription(mut self, config: InputAudioTranscription) -> Self {
self.input_audio_transcription = Some(config);
self
}
pub fn with_turn_detection(mut self, config: TurnDetection) -> Self {
self.turn_detection = Some(config);
self
}
pub fn with_tools(mut self, tools: Vec<Tool>) -> Self {
self.tools = Some(tools.into_iter().map(RealtimeTool::from).collect());
self
}
pub fn with_realtime_tools(mut self, tools: Vec<RealtimeTool>) -> Self {
self.tools = Some(tools);
self
}
pub fn with_tool_choice(mut self, choice: ToolChoice) -> Self {
self.tool_choice = Some(choice);
self
}
pub fn with_temperature(mut self, temp: f32) -> Self {
self.temperature = Some(temp);
self
}
pub fn with_max_tokens(mut self, max: MaxTokens) -> Self {
self.max_response_output_tokens = Some(max);
self
}
}
/// An output-token limit: either a concrete count or "no limit".
///
/// Serialization is hand-written rather than derived: `Count(n)` is written as
/// the number `n` and `Infinite` as the string `"inf"`.
#[derive(Debug, Clone)]
pub enum MaxTokens {
/// A concrete token count; serialized as an unsigned 32-bit integer.
Count(u32),
/// Serialized as the string `"inf"`.
Infinite,
}
impl serde::Serialize for MaxTokens {
    /// Writes `Count(n)` as the integer `n` and `Infinite` as the string `"inf"`.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        match *self {
            Self::Count(limit) => serializer.serialize_u32(limit),
            Self::Infinite => serializer.serialize_str("inf"),
        }
    }
}
impl<'de> serde::Deserialize<'de> for MaxTokens {
    /// Reads a [`MaxTokens`] from either an integer or the string `"inf"`.
    ///
    /// # Errors
    ///
    /// Fails when the integer is negative or larger than `u32::MAX` (it is
    /// rejected rather than truncated), when a string other than `"inf"` is
    /// supplied, or when the input is of any other type.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::{self, Unexpected, Visitor};

        struct MaxTokensVisitor;

        impl<'de> Visitor<'de> for MaxTokensVisitor {
            type Value = MaxTokens;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("an integer between 0 and 4294967295 or \"inf\"")
            }

            fn visit_u64<E>(self, value: u64) -> std::result::Result<MaxTokens, E>
            where
                E: de::Error,
            {
                // Checked conversion: a plain `as u32` cast would wrap values
                // above `u32::MAX` into an unrelated, much smaller limit.
                <u32 as std::convert::TryFrom<u64>>::try_from(value)
                    .map(MaxTokens::Count)
                    .map_err(|_| E::invalid_value(Unexpected::Unsigned(value), &self))
            }

            fn visit_i64<E>(self, value: i64) -> std::result::Result<MaxTokens, E>
            where
                E: de::Error,
            {
                // Some data formats report every integer as signed; accept
                // those as long as the value is a valid `u32`.
                <u32 as std::convert::TryFrom<i64>>::try_from(value)
                    .map(MaxTokens::Count)
                    .map_err(|_| E::invalid_value(Unexpected::Signed(value), &self))
            }

            fn visit_str<E>(self, value: &str) -> std::result::Result<MaxTokens, E>
            where
                E: de::Error,
            {
                match value {
                    "inf" => Ok(MaxTokens::Infinite),
                    other => Err(E::custom(format!("unknown value: {}", other))),
                }
            }
        }

        // The value can be a number or a string, so let the input decide
        // which visitor method is called.
        deserializer.deserialize_any(MaxTokensVisitor)
    }
}
impl From<u32> for MaxTokens {
fn from(count: u32) -> Self {
Self::Count(count)
}
}
/// Tool-choice setting: either one of the simple modes or a specific named function.
///
/// The enum is `untagged`, so the serialized form is just the inner value: a
/// bare lowercase string for [`ToolChoice::Simple`], or an object with `type`
/// and `function` keys for [`ToolChoice::Function`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
/// One of the modes in [`SimpleToolChoice`].
Simple(SimpleToolChoice),
/// A specific function, identified by name.
Function(NamedToolChoice),
}
/// The tool-choice modes that carry no payload; variant names are serialized
/// in lowercase.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SimpleToolChoice {
/// Serialized as `"auto"`. This is the mode used by `ToolChoice::default()`.
Auto,
/// Serialized as `"none"`.
None,
/// Serialized as `"required"`.
Required,
}
impl Default for ToolChoice {
fn default() -> Self {
Self::Simple(SimpleToolChoice::Auto)
}
}
impl ToolChoice {
pub fn auto() -> Self {
Self::Simple(SimpleToolChoice::Auto)
}
pub fn none() -> Self {
Self::Simple(SimpleToolChoice::None)
}
pub fn required() -> Self {
Self::Simple(SimpleToolChoice::Required)
}
pub fn function(name: impl Into<String>) -> Self {
Self::Function(NamedToolChoice { type_name: "function".to_string(), function: NamedFunction { name: name.into() } })
}
}
/// Payload of [`ToolChoice::Function`]: a tool type plus the function it refers to.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NamedToolChoice {
/// Kind of tool; serialized under the key `type`. `ToolChoice::function`
/// sets this to `"function"`.
#[serde(rename = "type")]
pub type_name: String,
/// The function being referred to.
pub function: NamedFunction,
}
/// Identifies a function by name inside a [`NamedToolChoice`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NamedFunction {
/// Name of the function.
pub name: String,
}
/// Per-response configuration.
///
/// Every field is optional and is left out of the serialized output when it is
/// `None`. `Default` (and [`ResponseCreateConfig::new`]) yields a value with
/// all fields unset.
// NOTE(review): field names look like they mirror a remote realtime API's
// response-creation payload — confirm against that API's reference before
// renaming any (note `max_output_tokens` here vs. `max_response_output_tokens`
// on `SessionConfig`).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ResponseCreateConfig {
/// Modalities for this response.
#[serde(skip_serializing_if = "Option::is_none")]
pub modalities: Option<Vec<Modality>>,
/// Instruction text for this response.
#[serde(skip_serializing_if = "Option::is_none")]
pub instructions: Option<String>,
/// Voice selection.
#[serde(skip_serializing_if = "Option::is_none")]
pub voice: Option<Voice>,
/// Format of output audio.
#[serde(skip_serializing_if = "Option::is_none")]
pub output_audio_format: Option<AudioFormat>,
/// Tool definitions, in [`RealtimeTool`] form.
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<RealtimeTool>>,
/// Tool-choice setting.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
/// Temperature value.
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// Output-token limit; see [`MaxTokens`] for the wire representation.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_output_tokens: Option<MaxTokens>,
/// Conversation selector; `ResponseCreateConfig::out_of_band` sets it to `"none"`.
#[serde(skip_serializing_if = "Option::is_none")]
pub conversation: Option<String>,
/// Free-form JSON metadata.
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<serde_json::Value>,
}
impl ResponseCreateConfig {
pub fn new() -> Self {
Self::default()
}
pub fn with_modalities(mut self, modalities: Vec<Modality>) -> Self {
self.modalities = Some(modalities);
self
}
pub fn with_instructions(mut self, instructions: impl Into<String>) -> Self {
self.instructions = Some(instructions.into());
self
}
pub fn with_voice(mut self, voice: Voice) -> Self {
self.voice = Some(voice);
self
}
pub fn out_of_band(mut self) -> Self {
self.conversation = Some("none".to_string());
self
}
}