use std::time::Duration;
use futures::stream::BoxStream;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ModelId(pub String);
impl ModelId {
pub fn new(id: impl Into<String>) -> Self {
Self(id.into())
}
pub fn as_str(&self) -> &str {
&self.0
}
}
impl Default for ModelId {
fn default() -> Self {
Self("default".into())
}
}
impl std::fmt::Display for ModelId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(&self.0)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub role: Role,
pub content: Vec<ContentBlock>,
}
impl Message {
pub fn new(role: Role, content: Vec<ContentBlock>) -> Self {
Self { role, content }
}
pub fn user(content: impl Into<String>) -> Self {
Self {
role: Role::User,
content: vec![ContentBlock::text(content)],
}
}
pub fn assistant(content: impl Into<String>) -> Self {
Self {
role: Role::Assistant,
content: vec![ContentBlock::text(content)],
}
}
pub fn system(content: impl Into<String>) -> Self {
Self {
role: Role::System,
content: vec![ContentBlock::text(content)],
}
}
pub fn tool_result(
tool_use_id: impl Into<String>,
content: impl Into<String>,
is_error: bool,
) -> Self {
Self {
role: Role::User,
content: vec![ContentBlock::ToolResult {
tool_use_id: tool_use_id.into(),
content: content.into(),
is_error,
}],
}
}
pub fn text(&self) -> String {
self.content
.iter()
.filter_map(|b| match b {
ContentBlock::Text { text } => Some(text.as_str()),
_ => None,
})
.collect::<Vec<_>>()
.join("")
}
pub fn tool_uses(&self) -> impl Iterator<Item = &ContentBlock> {
self.content
.iter()
.filter(|b| matches!(b, ContentBlock::ToolUse { .. }))
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
Text {
text: String,
},
ToolUse {
id: String,
name: String,
input: serde_json::Value,
},
ToolResult {
tool_use_id: String,
content: String,
#[serde(default)]
is_error: bool,
},
}
impl ContentBlock {
pub fn text(text: impl Into<String>) -> Self {
Self::Text { text: text.into() }
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
System,
User,
Assistant,
}
#[derive(Debug, Clone, Default)]
pub struct Request {
pub messages: Vec<Message>,
pub model: ModelId,
pub max_tokens: Option<u32>,
pub temperature: Option<f32>,
pub system: Option<String>,
pub stop: Vec<String>,
pub tools: Vec<ToolDefinition>,
}
#[derive(Debug, Clone)]
pub struct Response {
pub content: String,
pub tool_calls: Vec<ToolUse>,
pub usage: Usage,
pub model: ModelId,
pub finish_reason: FinishReason,
pub latency: Duration,
pub raw: serde_json::Value,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolDefinition {
pub name: String,
pub description: String,
pub input_schema: serde_json::Value,
}
impl ToolDefinition {
pub fn new(
name: impl Into<String>,
description: impl Into<String>,
input_schema: serde_json::Value,
) -> Self {
Self {
name: name.into(),
description: description.into(),
input_schema,
}
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolUse {
pub id: String,
pub name: String,
pub input: serde_json::Value,
}
#[derive(Debug, Clone, Default)]
pub struct Usage {
pub input_tokens: u32,
pub output_tokens: u32,
}
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum FinishReason {
#[default]
Stop,
MaxTokens,
ContentFilter,
ToolUse,
Other(String),
}
impl std::fmt::Display for Response {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
writeln!(f, "[{}] ({:.0?})", self.model, self.latency)?;
writeln!(f, "{}", self.content)?;
write!(
f,
"tokens: {} in / {} out | finish: {:?}",
self.usage.input_tokens, self.usage.output_tokens, self.finish_reason
)
}
}
#[derive(Debug, Clone)]
pub enum StreamChunk {
Delta(String),
Done { usage: Option<Usage> },
Error(String),
}
pub type StreamResponse<'a> = BoxStream<'a, StreamChunk>;
#[derive(Debug, Clone)]
pub struct EmbedRequest {
pub model: ModelId,
pub input: Vec<String>,
}
#[derive(Debug, Clone)]
pub struct Embedding {
pub vectors: Vec<Vec<f32>>,
pub model: ModelId,
pub usage: Usage,
}
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashSet;
use std::time::Duration;
#[test]
fn model_id_new_and_as_str() {
let m = ModelId::new("gpt-4");
assert_eq!(m.as_str(), "gpt-4");
}
#[test]
fn model_id_default() {
assert_eq!(ModelId::default().as_str(), "default");
}
#[test]
fn model_id_display() {
let m = ModelId::new("claude-3");
assert_eq!(format!("{m}"), "claude-3");
}
#[test]
fn model_id_eq_and_hash() {
let a = ModelId::new("x");
let b = ModelId::new("x");
let c = ModelId::new("y");
assert_eq!(a, b);
assert_ne!(a, c);
let mut set = HashSet::new();
set.insert(a);
set.insert(b);
assert_eq!(set.len(), 1);
}
#[test]
fn model_id_serde_roundtrip() {
let m = ModelId::new("llama3.2");
let json = serde_json::to_string(&m).unwrap();
let back: ModelId = serde_json::from_str(&json).unwrap();
assert_eq!(m, back);
}
#[test]
fn message_user() {
let m = Message::user("hi");
assert_eq!(m.role, Role::User);
assert_eq!(m.text(), "hi");
}
#[test]
fn message_assistant() {
let m = Message::assistant("ok");
assert_eq!(m.role, Role::Assistant);
assert_eq!(m.text(), "ok");
}
#[test]
fn message_system() {
let m = Message::system("you are helpful");
assert_eq!(m.role, Role::System);
assert_eq!(m.text(), "you are helpful");
}
#[test]
fn message_text_concatenates_and_ignores_non_text() {
let m = Message::new(
Role::Assistant,
vec![
ContentBlock::text("a"),
ContentBlock::ToolUse {
id: "t1".into(),
name: "x".into(),
input: serde_json::json!({}),
},
ContentBlock::text("b"),
],
);
assert_eq!(m.text(), "ab");
assert_eq!(m.tool_uses().count(), 1);
}
#[test]
fn message_tool_result_helper() {
let m = Message::tool_result("tu_1", "result body", false);
assert_eq!(m.role, Role::User);
match &m.content[0] {
ContentBlock::ToolResult {
tool_use_id,
content,
is_error,
} => {
assert_eq!(tool_use_id, "tu_1");
assert_eq!(content, "result body");
assert!(!is_error);
}
other => panic!("expected ToolResult, got {other:?}"),
}
}
#[test]
fn content_block_serde_roundtrip() {
for block in [
ContentBlock::text("hi"),
ContentBlock::ToolUse {
id: "id".into(),
name: "search".into(),
input: serde_json::json!({"q": "x"}),
},
ContentBlock::ToolResult {
tool_use_id: "id".into(),
content: "ok".into(),
is_error: false,
},
] {
let json = serde_json::to_string(&block).unwrap();
let back: ContentBlock = serde_json::from_str(&json).unwrap();
assert_eq!(block, back);
}
}
#[test]
fn role_serde_roundtrip() {
for (role, expected) in [
(Role::User, "\"user\""),
(Role::Assistant, "\"assistant\""),
(Role::System, "\"system\""),
] {
let json = serde_json::to_string(&role).unwrap();
assert_eq!(json, expected);
let back: Role = serde_json::from_str(&json).unwrap();
assert_eq!(back, role);
}
}
#[test]
fn request_default() {
let r = Request::default();
assert!(r.messages.is_empty());
assert_eq!(r.model, ModelId::default());
assert!(r.max_tokens.is_none());
assert!(r.temperature.is_none());
assert!(r.system.is_none());
assert!(r.stop.is_empty());
}
#[test]
fn response_display() {
let resp = Response {
content: "Hello!".into(),
tool_calls: vec![],
usage: Usage {
input_tokens: 10,
output_tokens: 5,
},
model: ModelId::new("test-model"),
finish_reason: FinishReason::Stop,
latency: Duration::from_millis(1234),
raw: serde_json::Value::Null,
};
let s = format!("{resp}");
assert!(s.contains("test-model"));
assert!(s.contains("Hello!"));
assert!(s.contains("10 in"));
assert!(s.contains("5 out"));
assert!(s.contains("Stop"));
assert!(s.contains("1"));
}
#[test]
fn finish_reason_default_is_stop() {
assert_eq!(FinishReason::default(), FinishReason::Stop);
}
#[test]
fn finish_reason_variants() {
assert_eq!(FinishReason::Stop, FinishReason::Stop);
assert_ne!(FinishReason::Stop, FinishReason::MaxTokens);
assert_ne!(FinishReason::MaxTokens, FinishReason::ContentFilter);
let other = FinishReason::Other("custom".into());
assert_eq!(other, FinishReason::Other("custom".into()));
assert_ne!(other, FinishReason::Other("different".into()));
}
#[test]
fn usage_default() {
let u = Usage::default();
assert_eq!(u.input_tokens, 0);
assert_eq!(u.output_tokens, 0);
}
#[test]
fn stream_chunk_debug() {
let _ = format!("{:?}", StreamChunk::Delta("hi".into()));
let _ = format!("{:?}", StreamChunk::Done { usage: None });
let _ = format!(
"{:?}",
StreamChunk::Done {
usage: Some(Usage::default())
}
);
let _ = format!("{:?}", StreamChunk::Error("err".into()));
}
#[test]
fn embed_request_construction() {
let r = EmbedRequest {
model: ModelId::new("nomic"),
input: vec!["hello".into(), "world".into()],
};
assert_eq!(r.model.as_str(), "nomic");
assert_eq!(r.input.len(), 2);
}
#[test]
fn embedding_construction() {
let e = Embedding {
vectors: vec![vec![0.1, 0.2], vec![0.3, 0.4]],
model: ModelId::new("nomic"),
usage: Usage {
input_tokens: 4,
output_tokens: 0,
},
};
assert_eq!(e.vectors.len(), 2);
assert_eq!(e.vectors[0].len(), 2);
assert_eq!(e.model.as_str(), "nomic");
}
}