use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use serde::{Deserialize, Serialize};
use crate::message::{ChatResponse, Message, StreamChunk};
use crate::tool::Tool;
/// A chat-completion backend.
///
/// Implementors must be `Send + Sync`, so a single provider value can be shared
/// across threads (for example behind an `Arc<dyn Provider>`).
///
/// `#[async_trait]` rewrites the `async fn` below into a method returning a
/// boxed future, which is what lets this trait be used as a trait object.
#[async_trait]
pub trait Provider: Send + Sync {
    /// Name of this provider.
    fn name(&self) -> &str;

    /// Maximum number of output tokens this provider supports, if it declares one.
    ///
    /// The default implementation returns `None`. NOTE(review): presumably callers
    /// read `None` as "no provider-specific limit" — confirm against call sites.
    fn max_output_tokens(&self) -> Option<u32> {
        None
    }

    /// Sends `messages` (plus the optional `tools` definitions) using `options`,
    /// and resolves to the reply as a single [`ChatResponse`].
    ///
    /// # Errors
    /// Returns an error when the request fails; the concrete failure modes are
    /// defined by each implementation.
    async fn chat(
        &self,
        messages: &[Message],
        tools: Option<&[Tool]>,
        options: &ChatOptions,
    ) -> Result<ChatResponse>;

    /// Streaming counterpart of [`Provider::chat`]: the reply arrives as a boxed
    /// stream of [`StreamChunk`]s, each wrapped in a `Result` so failures can be
    /// reported mid-stream.
    ///
    /// The stream borrows `self` and every argument for `'a`, so all of them must
    /// outlive it. This method is not `async`: it returns the stream directly.
    /// NOTE(review): whether any work starts before the stream is first polled is
    /// up to each implementation.
    fn stream_chat<'a>(
        &'a self,
        messages: &'a [Message],
        tools: Option<&'a [Tool]>,
        options: &'a ChatOptions,
    ) -> BoxStream<'a, Result<StreamChunk>>;
}
/// Per-request generation settings passed to [`Provider::chat`] and
/// [`Provider::stream_chat`].
///
/// Every field is optional. `None` fields are omitted entirely when serialized
/// (`skip_serializing_if = "Option::is_none"`), so only explicitly set values
/// are sent.
///
/// Beware the asymmetry between construction paths:
/// - [`ChatOptions::default`] / [`ChatOptions::new`] pre-populate `temperature`
///   and `max_tokens`.
/// - Deserializing a payload that omits a field yields `None` for that field
///   (serde's standard `Option` handling), not the `Default` value.
///
/// `PartialEq` is derived so whole option sets can be compared, e.g. in tests.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ChatOptions {
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Maximum number of tokens to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Top-p (nucleus sampling) cutoff.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Stop sequences.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    /// System prompt.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    /// Model name for this request. NOTE(review): presumably `None` means "use the
    /// provider's configured model" — confirm in the provider implementations.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
}
impl Default for ChatOptions {
fn default() -> Self {
Self {
temperature: Some(0.7),
max_tokens: Some(4096),
top_p: None,
stop: None,
system: None,
model: None,
}
}
}
impl ChatOptions {
    /// Creates options identical to [`ChatOptions::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the sampling temperature.
    #[must_use]
    pub fn temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the maximum number of tokens to generate.
    #[must_use]
    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the system prompt.
    #[must_use]
    pub fn system<S: Into<String>>(mut self, system: S) -> Self {
        self.system = Some(system.into());
        self
    }

    /// Sets the top-p (nucleus sampling) cutoff.
    #[must_use]
    pub fn top_p(mut self, top_p: f32) -> Self {
        self.top_p = Some(top_p);
        self
    }

    /// Sets the model name for this request.
    #[must_use]
    pub fn model<S: Into<String>>(mut self, model: S) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Sets the stop sequences, replacing any previously set.
    ///
    /// Accepts any iterable of string-like items, e.g. `["\n\n", "END"]` or a
    /// `Vec<String>`. An empty iterable clears the field (stores `None`), so no
    /// empty `stop` list is ever serialized.
    #[must_use]
    pub fn stop<I, S>(mut self, stop: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let sequences: Vec<String> = stop.into_iter().map(Into::into).collect();
        self.stop = (!sequences.is_empty()).then_some(sequences);
        self
    }

    /// Preset with `temperature = 0.0` and the given `max_tokens`.
    ///
    /// All other fields keep their [`Default`] values (`top_p`, `stop`, `system`
    /// and `model` unset).
    #[must_use]
    pub fn deterministic(max_tokens: u32) -> Self {
        Self {
            temperature: Some(0.0),
            max_tokens: Some(max_tokens),
            ..Default::default()
        }
    }

    /// Preset with `temperature = 0.1`, `top_p = 0.9` and the given `max_tokens`.
    ///
    /// All other fields keep their [`Default`] values.
    #[must_use]
    pub fn factual(max_tokens: u32) -> Self {
        Self {
            temperature: Some(0.1),
            max_tokens: Some(max_tokens),
            top_p: Some(0.9),
            ..Default::default()
        }
    }

    /// Preset with `temperature = 0.3` and the given `max_tokens`.
    ///
    /// All other fields keep their [`Default`] values.
    ///
    /// NOTE(review): 0.3 is *lower* than the `Default` temperature (0.7), so this
    /// "creative" preset samples less randomly than plain `ChatOptions::new()`.
    /// Confirm that this is intended.
    #[must_use]
    pub fn creative(max_tokens: u32) -> Self {
        Self {
            temperature: Some(0.3),
            max_tokens: Some(max_tokens),
            ..Default::default()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Default` sets only `temperature` and `max_tokens`; everything else is unset.
    #[test]
    fn test_chat_options_default() {
        let opts = ChatOptions::default();
        assert_eq!(opts.temperature, Some(0.7));
        assert_eq!(opts.max_tokens, Some(4096));
        assert!(opts.top_p.is_none());
        assert!(opts.stop.is_none());
        assert!(opts.system.is_none());
        assert!(opts.model.is_none());
    }

    #[test]
    fn test_chat_options_builder() {
        let opts = ChatOptions::new()
            .temperature(0.5)
            .max_tokens(2048)
            .system("Test");
        assert_eq!(opts.temperature, Some(0.5));
        assert_eq!(opts.max_tokens, Some(2048));
        assert_eq!(opts.system, Some("Test".to_string()));
    }

    /// The `top_p` and `model` setters work and leave unrelated defaults intact.
    #[test]
    fn test_chat_options_top_p_and_model() {
        let opts = ChatOptions::new().top_p(0.95).model("some-model");
        assert_eq!(opts.top_p, Some(0.95));
        assert_eq!(opts.model.as_deref(), Some("some-model"));
        assert_eq!(opts.temperature, Some(0.7));
        assert_eq!(opts.max_tokens, Some(4096));
    }

    #[test]
    fn test_chat_options_deterministic() {
        let opts = ChatOptions::deterministic(50);
        assert_eq!(opts.temperature, Some(0.0));
        assert_eq!(opts.max_tokens, Some(50));
        assert!(opts.top_p.is_none());
    }

    #[test]
    fn test_chat_options_factual() {
        let opts = ChatOptions::factual(200);
        assert_eq!(opts.temperature, Some(0.1));
        assert_eq!(opts.max_tokens, Some(200));
        assert_eq!(opts.top_p, Some(0.9));
    }

    #[test]
    fn test_chat_options_creative() {
        let opts = ChatOptions::creative(400);
        assert_eq!(opts.temperature, Some(0.3));
        assert_eq!(opts.max_tokens, Some(400));
        assert!(opts.top_p.is_none());
    }
}