use serde::{Deserialize, Serialize};
/// Parameters of a sampling request.
///
/// All keys are serialised in camelCase (`modelPreferences`, `systemPrompt`,
/// `maxTokens`, `stopSequences`, ...). Optional fields that are `None` are left
/// out of the serialised output entirely rather than emitted as `null`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct SamplingRequest {
    /// Messages included in the request.
    pub messages: Vec<SamplingMessage>,
    /// Optional model-selection preferences (`modelPreferences`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model_preferences: Option<ModelPreferences>,
    /// Optional system prompt (`systemPrompt`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_prompt: Option<String>,
    /// Required token limit, always serialised as `maxTokens`.
    pub max_tokens: u32,
    /// Optional sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Optional stop sequences (`stopSequences`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<String>>,
}
impl SamplingRequest {
    /// Builds a request from its two required parts.
    ///
    /// Every optional field (`model_preferences`, `system_prompt`,
    /// `temperature`, `stop_sequences`) starts as `None`. The fields are
    /// public, so callers can assign them after construction.
    pub fn new(messages: Vec<SamplingMessage>, max_tokens: u32) -> Self {
        Self {
            max_tokens,
            messages,
            stop_sequences: None,
            temperature: None,
            system_prompt: None,
            model_preferences: None,
        }
    }
}
/// One message of a sampling conversation: who said it and what was said.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[non_exhaustive]
pub struct SamplingMessage {
    /// Author role as a free-form string, e.g. `"user"` or `"assistant"`.
    pub role: String,
    /// Message payload.
    pub content: SamplingContent,
}
impl SamplingMessage {
    /// Creates a message from any string-like `role` (`&str`, `String`, ...)
    /// and its `content`.
    pub fn new(role: impl Into<String>, content: SamplingContent) -> Self {
        let role: String = role.into();
        Self { role, content }
    }
}
/// Content carried by a sampling message.
///
/// Internally tagged: the variant name is written in lowercase under a
/// `"type"` key alongside the variant's fields. Only the `text` variant
/// exists, so any other `"type"` value fails to deserialise.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
#[serde(rename_all = "lowercase")]
#[non_exhaustive]
pub enum SamplingContent {
    /// Plain text, serialised as `{"type":"text","text":...}`.
    Text { text: String },
}
/// Optional model-selection preferences attached to a sampling request.
///
/// Keys are serialised in camelCase (`costPriority`, `speedPriority`,
/// `intelligencePriority`). Every field is optional and omitted from output
/// when `None`. `Default` yields a value with all fields unset.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct ModelPreferences {
    /// Model name hints.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hints: Option<Vec<ModelHint>>,
    /// Cost priority (`costPriority`); no range is enforced here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cost_priority: Option<f32>,
    /// Speed priority (`speedPriority`); no range is enforced here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub speed_priority: Option<f32>,
    /// Intelligence priority (`intelligencePriority`); no range is enforced here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub intelligence_priority: Option<f32>,
}
/// A model-name hint listed in `ModelPreferences::hints`.
#[non_exhaustive]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ModelHint {
    /// Model name, e.g. `"claude-3-5-sonnet"`.
    pub name: String,
}

impl ModelHint {
    /// Creates a hint for the given model name.
    ///
    /// The struct is `#[non_exhaustive]` and has no `Default`, so code outside
    /// this crate cannot build it with a struct literal. This constructor is
    /// the only way for downstream users to create a hint.
    pub fn new(name: impl Into<String>) -> Self {
        Self { name: name.into() }
    }
}
#[non_exhaustive]
#[derive(Clone, Debug, Deserialize)]
pub struct SamplingResponse {
pub role: String,
pub content: SamplingContent,
pub model: String,
#[serde(rename = "stopReason")]
pub stop_reason: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a text-content message; keeps the fixtures below short.
    fn text_message(role: &str, text: &str) -> SamplingMessage {
        SamplingMessage {
            role: role.into(),
            content: SamplingContent::Text { text: text.into() },
        }
    }

    #[test]
    fn sampling_request_serialises_camelcase_field_names() {
        let req = SamplingRequest {
            messages: vec![text_message("user", "hello")],
            model_preferences: Some(ModelPreferences {
                hints: Some(vec![ModelHint {
                    name: "claude-3-5-sonnet".into(),
                }]),
                cost_priority: Some(0.2),
                speed_priority: Some(0.5),
                intelligence_priority: Some(0.9),
            }),
            system_prompt: Some("be concise".into()),
            max_tokens: 256,
            temperature: Some(0.7),
            stop_sequences: Some(vec!["STOP".into()]),
        };
        let json = serde_json::to_value(&req).expect("serialises");

        for key in [
            "messages",
            "modelPreferences",
            "systemPrompt",
            "maxTokens",
            "temperature",
            "stopSequences",
        ] {
            assert!(json.get(key).is_some(), "{} key present", key);
        }
        for key in ["model_preferences", "system_prompt", "max_tokens", "stop_sequences"] {
            assert!(json.get(key).is_none(), "no snake_case leak for {}", key);
        }

        // Values, not just keys, must survive serialisation.
        assert_eq!(json["maxTokens"].as_u64(), Some(256));
        assert_eq!(json["systemPrompt"].as_str(), Some("be concise"));
        assert_eq!(json["stopSequences"][0].as_str(), Some("STOP"));
        assert_eq!(json["messages"][0]["role"].as_str(), Some("user"));
        assert_eq!(json["messages"][0]["content"]["type"].as_str(), Some("text"));
        assert_eq!(json["messages"][0]["content"]["text"].as_str(), Some("hello"));

        let prefs = &json["modelPreferences"];
        for key in ["hints", "costPriority", "speedPriority", "intelligencePriority"] {
            assert!(prefs.get(key).is_some(), "{} present", key);
        }
        assert_eq!(
            prefs["hints"][0]["name"].as_str(),
            Some("claude-3-5-sonnet")
        );
    }

    #[test]
    fn sampling_request_omits_unset_optional_fields() {
        let minimal = SamplingRequest {
            messages: vec![],
            model_preferences: None,
            system_prompt: None,
            max_tokens: 16,
            temperature: None,
            stop_sequences: None,
        };
        let minimal_json = serde_json::to_value(&minimal).expect("serialises");
        assert!(minimal_json.get("systemPrompt").is_none());
        assert!(minimal_json.get("modelPreferences").is_none());
        assert!(minimal_json.get("stopSequences").is_none());
        assert!(minimal_json.get("temperature").is_none());
        // Required fields are still emitted.
        assert_eq!(minimal_json["maxTokens"].as_u64(), Some(16));
        assert!(minimal_json["messages"].is_array());
    }

    #[test]
    fn sampling_request_round_trips_through_json() {
        let req = SamplingRequest::new(vec![text_message("user", "x")], 32);
        let encoded = serde_json::to_string(&req).expect("serialises");
        let back: SamplingRequest = serde_json::from_str(&encoded).expect("parses");
        assert_eq!(back.max_tokens, 32);
        assert_eq!(back.messages.len(), 1);
        assert_eq!(back.messages[0].role, "user");
        assert!(back.model_preferences.is_none());
        assert!(back.system_prompt.is_none());
        assert!(back.temperature.is_none());
        assert!(back.stop_sequences.is_none());
    }

    #[test]
    fn sampling_content_text_serialises_with_type_tag() {
        let content = SamplingContent::Text { text: "hi".into() };
        let json = serde_json::to_value(&content).expect("serialises");
        assert_eq!(json, serde_json::json!({ "type": "text", "text": "hi" }));
    }

    #[test]
    fn sampling_response_deserialises_text_content() {
        let fixture = r#"{"role":"assistant","content":{"type":"text","text":"hi"},"model":"test","stopReason":"endTurn"}"#;
        let resp: SamplingResponse = serde_json::from_str(fixture).expect("parses");
        assert_eq!(resp.role, "assistant");
        assert_eq!(resp.model, "test");
        assert_eq!(resp.stop_reason.as_deref(), Some("endTurn"));
        match resp.content {
            SamplingContent::Text { text } => assert_eq!(text, "hi"),
        }
    }

    #[test]
    fn sampling_response_without_stop_reason_deserialises_to_none() {
        // Both an absent key and an explicit null must be accepted.
        for fixture in [
            r#"{"role":"assistant","content":{"type":"text","text":"hi"},"model":"test"}"#,
            r#"{"role":"assistant","content":{"type":"text","text":"hi"},"model":"test","stopReason":null}"#,
        ] {
            let resp: SamplingResponse = serde_json::from_str(fixture).expect("parses");
            assert!(resp.stop_reason.is_none(), "fixture: {}", fixture);
        }
    }

    #[test]
    fn sampling_content_image_fails_to_deserialise() {
        let fixture = r#"{"type":"image","data":"...","mimeType":"image/png"}"#;
        let result: Result<SamplingContent, _> = serde_json::from_str(fixture);
        assert!(
            result.is_err(),
            "image-tagged content must surface as unknown-variant error"
        );
    }

    #[test]
    fn sampling_request_new_defaults_optional_fields_to_none() {
        let msg = SamplingMessage::new("user", SamplingContent::Text { text: "hi".into() });
        let req = SamplingRequest::new(vec![msg], 128);
        assert_eq!(req.messages.len(), 1);
        assert_eq!(req.max_tokens, 128);
        assert!(req.model_preferences.is_none());
        assert!(req.system_prompt.is_none());
        assert!(req.temperature.is_none());
        assert!(req.stop_sequences.is_none());
    }

    #[test]
    fn sampling_message_new_accepts_str_and_string_role() {
        let m1 = SamplingMessage::new("user", SamplingContent::Text { text: "a".into() });
        let m2 = SamplingMessage::new(
            String::from("assistant"),
            SamplingContent::Text { text: "b".into() },
        );
        assert_eq!(m1.role, "user");
        assert_eq!(m2.role, "assistant");
        assert!(matches!(m1.content, SamplingContent::Text { .. }));
    }
}