use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct CompletionRequest {
pub model: String,
pub system: Option<String>,
pub messages: Vec<Message>,
pub tools: Vec<ToolSpec>,
pub tool_choice: ToolChoice,
pub response_format: Option<JsonSchema>,
pub max_tokens: Option<u32>,
pub temperature: Option<f32>,
pub stop: Vec<String>,
pub web_search: bool,
}
impl CompletionRequest {
    /// Creates a request for `model` with every optional setting left unset.
    ///
    /// The returned request has:
    /// - no system text, messages, tools or stop strings;
    /// - no response format, `max_tokens` or `temperature`;
    /// - `tool_choice` set to [`ToolChoice::Auto`] and `web_search` set to `false`.
    #[must_use]
    pub fn new(model: impl Into<String>) -> Self {
        let model: String = model.into();
        Self {
            model,
            tool_choice: ToolChoice::Auto,
            web_search: false,
            system: None,
            response_format: None,
            max_tokens: None,
            temperature: None,
            messages: vec![],
            tools: vec![],
            stop: vec![],
        }
    }
}
/// A single conversation turn: its author and its ordered content parts.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Message {
    /// Who produced this turn.
    pub role: Role,
    /// Content parts in order; each may be text, a tool call, a tool result or an image.
    pub content: Vec<Content>,
}
impl Message {
#[must_use]
pub fn user(text: impl Into<String>) -> Self {
Self {
role: Role::User,
content: vec![Content::Text(text.into())],
}
}
#[must_use]
pub fn assistant(text: impl Into<String>) -> Self {
Self {
role: Role::Assistant,
content: vec![Content::Text(text.into())],
}
}
#[must_use]
pub fn system(text: impl Into<String>) -> Self {
Self {
role: Role::System,
content: vec![Content::Text(text.into())],
}
}
}
/// Author of a [`Message`].
///
/// Serialized as a snake_case string: `"user"`, `"assistant"`, `"system"` or `"tool"`.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum Role {
    /// A turn written by the end user.
    User,
    /// A turn produced by the model.
    Assistant,
    /// A system-level turn.
    System,
    /// A turn carrying tool output (the tests pair it with `Content::ToolResult`).
    Tool,
}
/// One part of a [`Message`]'s content.
///
/// Uses serde's default external tagging with snake_case variant names. For
/// example, `Content::Text("hi")` serializes as `{"text":"hi"}` and a tool call
/// as `{"tool_use":{...}}`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum Content {
    /// Plain text.
    Text(String),
    /// A tool invocation; see [`ToolCall`].
    ToolUse(ToolCall),
    /// The outcome of a tool invocation; see [`ToolResult`].
    ToolResult(ToolResult),
    /// A reference to an image; see [`ImageRef`].
    Image(ImageRef),
}
impl Content {
    /// Wraps `s` as a [`Content::Text`] part.
    #[must_use]
    pub fn text(s: impl Into<String>) -> Self {
        let owned: String = s.into();
        Self::Text(owned)
    }

    /// Builds an unsigned [`Content::ToolUse`] part.
    ///
    /// Shorthand for [`Content::tool_use_signed`] with `signature` set to `None`.
    #[must_use]
    pub fn tool_use(
        id: impl Into<String>,
        name: impl Into<String>,
        args_json: impl Into<String>,
    ) -> Self {
        Self::tool_use_signed(id, name, args_json, None)
    }

    /// Builds a [`Content::ToolUse`] part with an optional `signature`.
    ///
    /// `args_json` is stored verbatim; it is neither parsed nor validated here.
    #[must_use]
    pub fn tool_use_signed(
        id: impl Into<String>,
        name: impl Into<String>,
        args_json: impl Into<String>,
        signature: Option<String>,
    ) -> Self {
        let call = ToolCall {
            id: id.into(),
            name: name.into(),
            args_json: args_json.into(),
            signature,
        };
        Self::ToolUse(call)
    }

    /// Builds a [`Content::ToolResult`] part for the call identified by `tool_call_id`.
    ///
    /// `result_json` is stored verbatim. `is_error` is carried through unchanged
    /// (presumably it marks a failed invocation).
    #[must_use]
    pub fn tool_result(
        tool_call_id: impl Into<String>,
        result_json: impl Into<String>,
        is_error: bool,
    ) -> Self {
        let outcome = ToolResult {
            tool_call_id: tool_call_id.into(),
            result_json: result_json.into(),
            is_error,
        };
        Self::ToolResult(outcome)
    }

    /// Builds a [`Content::Image`] part pointing at `url`, with an optional MIME type.
    #[must_use]
    pub fn image(url: impl Into<String>, mime_type: Option<String>) -> Self {
        let image = ImageRef {
            url: url.into(),
            mime_type,
        };
        Self::Image(image)
    }
}
/// A tool invocation emitted by the model.
///
/// `args_json` is kept as an opaque string and is never parsed into a JSON value.
/// `signature` is left out of serialized output when `None` and is read back as
/// `None` when missing.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolCall {
    /// Identifier of this call.
    pub id: String,
    /// Name of the tool being invoked.
    pub name: String,
    /// Raw JSON-encoded arguments, stored verbatim.
    pub args_json: String,
    /// Optional opaque signature attached to the call.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub signature: Option<String>,
}
/// The outcome reported back for a tool invocation.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolResult {
    /// Id of the call being answered (presumably matches a `ToolCall::id`; the tests pair them that way).
    pub tool_call_id: String,
    /// Raw JSON-encoded result, stored verbatim.
    pub result_json: String,
    /// Error flag, carried through unchanged.
    pub is_error: bool,
}
/// A reference to an image by URL.
///
/// `Default` yields an empty URL and no MIME type. `mime_type` is always
/// serialized, as `null` when absent.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct ImageRef {
    /// Location of the image.
    pub url: String,
    /// Optional MIME type such as `"image/png"`.
    pub mime_type: Option<String>,
}
/// Declaration of a tool offered to the model.
///
/// `title` is omitted from serialized output when `None`, and `needs_approval`
/// when `false`. Both fall back to those values when missing on input.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolSpec {
    /// Machine name of the tool.
    pub name: String,
    /// Free-form description of the tool.
    pub description: String,
    /// Argument schema as a JSON value (presumably JSON Schema; not validated here).
    pub schema_json: serde_json::Value,
    /// Optional human-readable title.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub title: Option<String>,
    /// Approval flag; `false` unless set.
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub needs_approval: bool,
}
/// Turns a machine-style tool name into a sentence-case label.
///
/// The name is split on `_`, `-` and any Unicode whitespace, and empty pieces
/// are discarded. Every word is lowercased, the first character of the first
/// word is uppercased, and words are joined with single spaces. A name with no
/// words yields an empty string.
///
/// ```text
/// "paid_fetch"  -> "Paid fetch"
/// "delete-file" -> "Delete file"
/// "DELETE_FILE" -> "Delete file"
/// ```
#[must_use]
pub fn humanize_tool_name(name: &str) -> String {
    let is_separator = |c: char| c == '_' || c == '-' || c.is_whitespace();
    let mut label = String::with_capacity(name.len());
    let words = name.split(is_separator).filter(|piece| !piece.is_empty());
    for (position, word) in words.enumerate() {
        let lowered = word.to_lowercase();
        if position > 0 {
            label.push(' ');
            label.push_str(&lowered);
            continue;
        }
        // Only the very first character of the label is capitalized; `to_uppercase`
        // may expand to several chars (e.g. 'ß' -> "SS").
        let mut rest = lowered.chars();
        if let Some(head) = rest.next() {
            label.extend(head.to_uppercase());
            label.push_str(rest.as_str());
        }
    }
    label
}
/// Policy for how the model may use the offered tools.
///
/// Unit variants serialize as bare strings (`"auto"`, `"none"`, `"required"`).
/// `Named` serializes as an object, e.g. `{"named": "search"}`.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
    /// Model decides; this is what `CompletionRequest::new` selects.
    Auto,
    /// No tool use (presumed meaning; enforcement happens outside this type).
    None,
    /// Some tool must be used (presumed meaning; enforcement happens outside this type).
    Required,
    /// A specific tool, identified by name.
    Named(String),
}
/// A JSON schema wrapped as a newtype.
///
/// `#[serde(transparent)]` makes it serialize and deserialize exactly as the
/// inner [`serde_json::Value`], with no wrapper object.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(transparent)]
pub struct JsonSchema(pub serde_json::Value);
#[cfg(test)]
mod tests {
    #![allow(clippy::pedantic, clippy::nursery, missing_docs)]
    use serde_json::{Value, json};
    use super::*;

    #[test]
    fn new_sets_model_and_defaults() {
        let req = CompletionRequest::new("fast-2");
        assert_eq!(req.model, "fast-2");
        assert!(req.messages.is_empty());
        assert!(req.tools.is_empty());
        assert!(req.stop.is_empty());
        assert!(req.system.is_none());
        assert!(req.max_tokens.is_none());
        assert!(req.temperature.is_none());
        assert!(req.response_format.is_none());
        assert_eq!(req.tool_choice, ToolChoice::Auto);
        assert!(!req.web_search, "web_search must default to false");
    }

    #[test]
    fn role_serializes_to_snake_case() {
        assert_eq!(serde_json::to_string(&Role::User).unwrap(), r#""user""#);
        assert_eq!(
            serde_json::to_string(&Role::Assistant).unwrap(),
            r#""assistant""#
        );
        assert_eq!(serde_json::to_string(&Role::System).unwrap(), r#""system""#);
        assert_eq!(serde_json::to_string(&Role::Tool).unwrap(), r#""tool""#);
    }

    #[test]
    fn role_round_trips() {
        for role in [Role::User, Role::Assistant, Role::System, Role::Tool] {
            let json = serde_json::to_string(&role).unwrap();
            let back: Role = serde_json::from_str(&json).unwrap();
            assert_eq!(back, role);
        }
    }

    #[test]
    fn tool_choice_unit_variants_serialize_as_strings() {
        assert_eq!(
            serde_json::to_string(&ToolChoice::Auto).unwrap(),
            r#""auto""#
        );
        assert_eq!(
            serde_json::to_string(&ToolChoice::None).unwrap(),
            r#""none""#
        );
        assert_eq!(
            serde_json::to_string(&ToolChoice::Required).unwrap(),
            r#""required""#
        );
    }

    #[test]
    fn tool_choice_named_serializes_as_object() {
        let tc = ToolChoice::Named("my_tool".to_owned());
        let v: Value = serde_json::to_value(&tc).unwrap();
        assert_eq!(v, json!({"named": "my_tool"}));
    }

    #[test]
    fn tool_choice_round_trips() {
        for tc in [
            ToolChoice::Auto,
            ToolChoice::None,
            ToolChoice::Required,
            ToolChoice::Named("search".to_owned()),
        ] {
            let json = serde_json::to_string(&tc).unwrap();
            let back: ToolChoice = serde_json::from_str(&json).unwrap();
            assert_eq!(back, tc);
        }
    }

    #[test]
    fn content_text_constructor() {
        let c = Content::text("hello");
        assert!(matches!(c, Content::Text(s) if s == "hello"));
    }

    #[test]
    fn content_tool_use_constructor() {
        let c = Content::tool_use("call-1", "search", r#"{"q":"rust"}"#);
        match c {
            Content::ToolUse(tu) => {
                assert_eq!(tu.id, "call-1");
                assert_eq!(tu.name, "search");
                assert_eq!(tu.args_json, r#"{"q":"rust"}"#);
                assert!(tu.signature.is_none());
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn content_tool_use_signed_constructor() {
        let c = Content::tool_use_signed("call-2", "search", "{}", Some("sig-abc".to_owned()));
        match c {
            Content::ToolUse(tu) => {
                assert_eq!(tu.id, "call-2");
                assert_eq!(tu.name, "search");
                assert_eq!(tu.args_json, "{}");
                assert_eq!(tu.signature.as_deref(), Some("sig-abc"));
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn tool_call_signature_omitted_when_none_and_kept_when_some() {
        let unsigned: Value = serde_json::to_value(&Content::tool_use("c", "t", "{}")).unwrap();
        assert!(
            unsigned["tool_use"].get("signature").is_none(),
            "a None signature must not be serialized"
        );

        let signed: Value = serde_json::to_value(&Content::tool_use_signed(
            "c",
            "t",
            "{}",
            Some("s".to_owned()),
        ))
        .unwrap();
        assert_eq!(signed["tool_use"]["signature"], "s");

        // A payload without the key must still deserialize, with signature = None.
        let back: Content = serde_json::from_value(unsigned).unwrap();
        assert!(matches!(back, Content::ToolUse(tu) if tu.signature.is_none()));
    }

    #[test]
    fn content_tool_result_constructor() {
        let c = Content::tool_result("call-1", r#"{"result":"ok"}"#, false);
        match c {
            Content::ToolResult(tr) => {
                assert_eq!(tr.tool_call_id, "call-1");
                assert_eq!(tr.result_json, r#"{"result":"ok"}"#);
                assert!(!tr.is_error);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn tool_result_error_flag_round_trips() {
        let c = Content::tool_result("call-9", r#"{"error":"boom"}"#, true);
        let json = serde_json::to_string(&c).unwrap();
        let back: Content = serde_json::from_str(&json).unwrap();
        assert!(matches!(back, Content::ToolResult(tr) if tr.is_error && tr.tool_call_id == "call-9"));
    }

    #[test]
    fn content_image_constructor() {
        let c = Content::image("https://example.com/img.png", Some("image/png".to_owned()));
        match c {
            Content::Image(img) => {
                assert_eq!(img.url, "https://example.com/img.png");
                assert_eq!(img.mime_type.as_deref(), Some("image/png"));
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn message_user_constructor() {
        let m = Message::user("hi");
        assert_eq!(m.role, Role::User);
        assert_eq!(m.content.len(), 1);
        assert!(matches!(&m.content[0], Content::Text(s) if s == "hi"));
    }

    #[test]
    fn message_assistant_constructor() {
        let m = Message::assistant("hello back");
        assert_eq!(m.role, Role::Assistant);
        assert_eq!(m.content.len(), 1);
        assert!(matches!(&m.content[0], Content::Text(s) if s == "hello back"));
    }

    #[test]
    fn message_system_constructor() {
        let m = Message::system("You are helpful.");
        assert_eq!(m.role, Role::System);
        assert_eq!(m.content.len(), 1);
        assert!(matches!(&m.content[0], Content::Text(s) if s == "You are helpful."));
    }

    #[test]
    fn tool_use_args_json_preserved_as_opaque_string() {
        let original = r#"{"nested":{"key":42},"arr":[1,2,3]}"#;
        let c = Content::tool_use("id-42", "complex_tool", original);
        let serialized = serde_json::to_string(&c).unwrap();
        let back: Content = serde_json::from_str(&serialized).unwrap();
        match back {
            Content::ToolUse(tu) => assert_eq!(tu.args_json, original),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn completion_request_round_trips_all_content_variants() {
        let mut req = CompletionRequest::new("test-model");
        req.system = Some("Be concise.".to_owned());
        req.max_tokens = Some(256);
        req.temperature = Some(0.7);
        req.stop = vec!["<end>".to_owned()];
        req.tool_choice = ToolChoice::Named("calculator".to_owned());
        req.response_format = Some(JsonSchema(json!({"type": "object"})));
        req.web_search = true;
        req.tools = vec![ToolSpec {
            name: "calculator".to_owned(),
            description: "Evaluates math expressions.".to_owned(),
            schema_json: json!({"type": "object", "properties": {"expr": {"type": "string"}}}),
            title: None,
            needs_approval: false,
        }];
        req.messages = vec![
            Message::user("Compute 2+2"),
            Message {
                role: Role::Assistant,
                content: vec![Content::tool_use(
                    "call-1",
                    "calculator",
                    r#"{"expr":"2+2"}"#,
                )],
            },
            Message {
                role: Role::Tool,
                content: vec![Content::tool_result("call-1", r#"{"value":4}"#, false)],
            },
            Message {
                role: Role::User,
                content: vec![Content::image(
                    "https://example.com/chart.png",
                    Some("image/png".to_owned()),
                )],
            },
        ];
        let json_str = serde_json::to_string(&req).unwrap();
        let back: CompletionRequest = serde_json::from_str(&json_str).unwrap();

        assert_eq!(back.model, "test-model");
        assert_eq!(back.system.as_deref(), Some("Be concise."));
        assert_eq!(back.max_tokens, Some(256));
        assert!((back.temperature.expect("temperature survives") - 0.7).abs() < f32::EPSILON);
        assert_eq!(back.stop, vec!["<end>".to_owned()]);
        assert_eq!(
            back.response_format.as_ref().map(|s| &s.0),
            Some(&json!({"type": "object"}))
        );
        assert!(back.web_search);
        assert_eq!(back.tool_choice, ToolChoice::Named("calculator".to_owned()));

        assert_eq!(back.tools.len(), 1);
        assert_eq!(back.tools[0].name, "calculator");
        assert!(back.tools[0].title.is_none());
        assert!(!back.tools[0].needs_approval);

        assert_eq!(back.messages.len(), 4);
        assert_eq!(back.messages[0].role, Role::User);
        assert!(matches!(&back.messages[0].content[0], Content::Text(s) if s == "Compute 2+2"));
        assert_eq!(back.messages[1].role, Role::Assistant);
        assert!(matches!(
            &back.messages[1].content[0],
            Content::ToolUse(tu) if tu.id == "call-1" && tu.args_json == r#"{"expr":"2+2"}"#
        ));
        assert_eq!(back.messages[2].role, Role::Tool);
        assert!(matches!(
            &back.messages[2].content[0],
            Content::ToolResult(tr) if tr.tool_call_id == "call-1" && !tr.is_error
        ));
        assert!(matches!(
            &back.messages[3].content[0],
            Content::Image(img) if img.mime_type.as_deref() == Some("image/png")
        ));
    }

    #[test]
    fn json_schema_serializes_transparently() {
        let schema = JsonSchema(json!({"type": "object", "required": ["name"]}));
        let v: Value = serde_json::to_value(&schema).unwrap();
        assert_eq!(v["type"], "object");
        assert_eq!(v["required"][0], "name");
    }

    #[test]
    fn json_schema_round_trips() {
        let inner = json!({"type": "string", "maxLength": 100});
        let schema = JsonSchema(inner.clone());
        let json_str = serde_json::to_string(&schema).unwrap();
        let back: JsonSchema = serde_json::from_str(&json_str).unwrap();
        assert_eq!(back.0, inner);
    }

    #[test]
    fn image_ref_default_is_sensible() {
        let img = ImageRef::default();
        assert!(img.url.is_empty());
        assert!(img.mime_type.is_none());
    }

    #[test]
    fn humanize_snake_case() {
        assert_eq!(humanize_tool_name("paid_fetch"), "Paid fetch");
        assert_eq!(humanize_tool_name("delete_file"), "Delete file");
    }

    #[test]
    fn humanize_kebab_case() {
        assert_eq!(humanize_tool_name("delete-file"), "Delete file");
    }

    #[test]
    fn humanize_single_word() {
        assert_eq!(humanize_tool_name("calculator"), "Calculator");
    }

    #[test]
    fn humanize_empty() {
        assert_eq!(humanize_tool_name(""), "");
    }

    #[test]
    fn humanize_already_spaced_passes_through() {
        assert_eq!(humanize_tool_name("Pay for a page"), "Pay for a page");
        assert_eq!(humanize_tool_name("delete file"), "Delete file");
    }

    #[test]
    fn humanize_lowercases_and_collapses_separator_runs() {
        assert_eq!(humanize_tool_name("DELETE_FILE"), "Delete file");
        assert_eq!(humanize_tool_name("__read--file  now"), "Read file now");
        assert_eq!(humanize_tool_name("_-  "), "");
    }

    #[test]
    fn tool_spec_carries_optional_title() {
        let spec = ToolSpec {
            name: "paid_fetch".to_owned(),
            description: "d".to_owned(),
            schema_json: json!({}),
            title: Some("Pay for & fetch a web page".to_owned()),
            needs_approval: false,
        };
        assert_eq!(spec.title.as_deref(), Some("Pay for & fetch a web page"));
    }

    #[test]
    fn tool_spec_carries_needs_approval_flag() {
        let spec = ToolSpec {
            name: "delete_file".to_owned(),
            description: "d".to_owned(),
            schema_json: json!({}),
            title: None,
            needs_approval: true,
        };
        assert!(spec.needs_approval);
    }

    #[test]
    fn tool_spec_omits_default_optional_fields_when_serialized() {
        let spec = ToolSpec {
            name: "calculator".to_owned(),
            description: "math".to_owned(),
            schema_json: json!({}),
            title: None,
            needs_approval: false,
        };
        let v: Value = serde_json::to_value(&spec).unwrap();
        assert!(v.get("title").is_none(), "None title must be skipped");
        assert!(v.get("needs_approval").is_none(), "false needs_approval must be skipped");

        let flagged = ToolSpec {
            title: Some("Calc".to_owned()),
            needs_approval: true,
            ..spec
        };
        let v: Value = serde_json::to_value(&flagged).unwrap();
        assert_eq!(v["title"], "Calc");
        assert_eq!(v["needs_approval"], true);
    }

    #[test]
    fn tool_spec_needs_approval_defaults_false_on_deserialize() {
        let payload = json!({
            "name": "calculator",
            "description": "math",
            "schema_json": {"type": "object"}
        });
        let spec: ToolSpec = serde_json::from_value(payload).unwrap();
        assert!(
            !spec.needs_approval,
            "omitted needs_approval must default to false"
        );
    }
}