use serde::{Deserialize, Serialize};
use serde_json::Value;
use super::LicenseSetting;
use crate::ChatMessage;
/// Request payload for creating a model, optionally derived from an existing model
/// named in `from`.
///
/// NOTE(review): the field set looks like the Ollama `/api/create` request body —
/// confirm against the server API this crate targets.
///
/// Every `Option` field carries `skip_serializing_if = "Option::is_none"`, so a `None`
/// value is omitted from the serialized JSON instead of being sent as `null`.
/// Serde derive emits keys in field declaration order, so reordering the fields below
/// changes the serialized output.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CreateRequest {
/// Name of the model to create (always serialized).
pub model: String,
/// Existing model to build from (tests use e.g. `"qwen3:0.6b"`).
#[serde(skip_serializing_if = "Option::is_none")]
pub from: Option<String>,
/// Prompt template text (tests use e.g. `"{{ .System }}\n\n{{ .Prompt }}"`).
#[serde(skip_serializing_if = "Option::is_none")]
pub template: Option<String>,
/// License information; see [`LicenseSetting`].
#[serde(skip_serializing_if = "Option::is_none")]
pub license: Option<LicenseSetting>,
/// System message text.
#[serde(skip_serializing_if = "Option::is_none")]
pub system: Option<String>,
/// Model parameters as an arbitrary JSON value, serialized as-is
/// (tests pass an object such as `{"temperature": 0.8, "top_k": 40}`).
#[serde(skip_serializing_if = "Option::is_none")]
pub parameters: Option<Value>,
/// Chat messages to attach to the model.
#[serde(skip_serializing_if = "Option::is_none")]
pub messages: Option<Vec<ChatMessage>>,
/// Quantization type name (tests use e.g. `"q4_K_M"`).
#[serde(skip_serializing_if = "Option::is_none")]
pub quantize: Option<String>,
/// Streaming flag sent to the server. The constructors on this type set it to
/// `Some(false)`; set it to `None` to omit the key entirely.
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
}
impl CreateRequest {
    /// Creates a request for the model named `model`.
    ///
    /// Every optional field starts as `None`, so it is omitted when serialized. The one
    /// exception is `stream`, which is explicitly `Some(false)`, so the serialized request
    /// always carries `"stream": false`.
    #[must_use]
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            from: None,
            template: None,
            license: None,
            system: None,
            parameters: None,
            messages: None,
            quantize: None,
            stream: Some(false),
        }
    }

    /// Creates a request for `model` that is based on the existing model `from`.
    ///
    /// Shorthand for `CreateRequest::new(model).with_from(from)`. Delegating keeps the
    /// default field values (notably `stream: Some(false)`) defined in one place.
    #[must_use]
    pub fn from_model(model: impl Into<String>, from: impl Into<String>) -> Self {
        Self::new(model).with_from(from)
    }

    /// Sets the base model, replacing any previous value.
    #[must_use]
    pub fn with_from(mut self, from: impl Into<String>) -> Self {
        self.from = Some(from.into());
        self
    }

    /// Sets the prompt template, replacing any previous value.
    #[must_use]
    pub fn with_template(mut self, template: impl Into<String>) -> Self {
        self.template = Some(template.into());
        self
    }

    /// Sets the license, replacing any previous value.
    ///
    /// Accepts anything convertible into [`LicenseSetting`].
    #[must_use]
    pub fn with_license(mut self, license: impl Into<LicenseSetting>) -> Self {
        self.license = Some(license.into());
        self
    }

    /// Sets the system message, replacing any previous value.
    #[must_use]
    pub fn with_system(mut self, system: impl Into<String>) -> Self {
        self.system = Some(system.into());
        self
    }

    /// Sets the model parameters to an arbitrary JSON value, replacing any previous value.
    #[must_use]
    pub fn with_parameters(mut self, parameters: Value) -> Self {
        self.parameters = Some(parameters);
        self
    }

    /// Replaces the message list with the messages yielded by `messages`.
    ///
    /// Any messages added earlier, via this method or [`Self::with_message`], are
    /// discarded.
    #[must_use]
    pub fn with_messages<I>(mut self, messages: I) -> Self
    where
        I: IntoIterator<Item = ChatMessage>,
    {
        self.messages = Some(messages.into_iter().collect());
        self
    }

    /// Appends a single message, creating the message list if none has been set yet.
    #[must_use]
    pub fn with_message(mut self, message: ChatMessage) -> Self {
        self.messages.get_or_insert_with(Vec::new).push(message);
        self
    }

    /// Sets the quantization type (e.g. `"q4_K_M"`), replacing any previous value.
    #[must_use]
    pub fn with_quantize(mut self, quantize: impl Into<String>) -> Self {
        self.quantize = Some(quantize.into());
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_create_request_new() {
        let request = CreateRequest::new("mario");
        assert_eq!(request.model, "mario");
        assert!(request.from.is_none());
        assert!(request.template.is_none());
        assert!(request.license.is_none());
        assert!(request.system.is_none());
        assert!(request.parameters.is_none());
        assert!(request.messages.is_none());
        assert!(request.quantize.is_none());
        assert_eq!(request.stream, Some(false));
    }

    #[test]
    fn test_create_request_from_model() {
        let request = CreateRequest::from_model("mario", "qwen3:0.6b");
        assert_eq!(request.model, "mario");
        assert_eq!(request.from.as_deref(), Some("qwen3:0.6b"));
        assert_eq!(request.stream, Some(false));
    }

    #[test]
    fn test_with_from_matches_from_model() {
        assert_eq!(
            CreateRequest::new("mario").with_from("qwen3:0.6b"),
            CreateRequest::from_model("mario", "qwen3:0.6b")
        );
    }

    #[test]
    fn test_create_request_builder_pattern() {
        let request = CreateRequest::from_model("mario", "qwen3:0.6b")
            .with_system("You are Mario from Super Mario Bros.")
            .with_template("{{ .System }}\n\n{{ .Prompt }}")
            .with_license("MIT")
            .with_quantize("q4_K_M");
        assert_eq!(request.model, "mario");
        assert_eq!(request.from.as_deref(), Some("qwen3:0.6b"));
        assert_eq!(
            request.system.as_deref(),
            Some("You are Mario from Super Mario Bros.")
        );
        assert_eq!(
            request.template.as_deref(),
            Some("{{ .System }}\n\n{{ .Prompt }}")
        );
        assert_eq!(
            request.license,
            Some(LicenseSetting::Single("MIT".to_string()))
        );
        assert_eq!(request.quantize.as_deref(), Some("q4_K_M"));
    }

    #[test]
    fn test_create_request_with_messages() {
        let request = CreateRequest::from_model("mario", "qwen3:0.6b").with_messages([
            ChatMessage::user("Who are you?"),
            ChatMessage::assistant("It's-a me, Mario!"),
        ]);
        let messages = request
            .messages
            .expect("with_messages should set the message list");
        assert_eq!(
            messages,
            vec![
                ChatMessage::user("Who are you?"),
                ChatMessage::assistant("It's-a me, Mario!"),
            ]
        );
    }

    #[test]
    fn test_create_request_with_single_message() {
        let request = CreateRequest::from_model("mario", "qwen3:0.6b")
            .with_message(ChatMessage::user("Who are you?"))
            .with_message(ChatMessage::assistant("It's-a me, Mario!"));
        let messages = request
            .messages
            .expect("with_message should create the message list");
        assert_eq!(
            messages,
            vec![
                ChatMessage::user("Who are you?"),
                ChatMessage::assistant("It's-a me, Mario!"),
            ]
        );
    }

    #[test]
    fn test_with_messages_replaces_and_with_message_appends() {
        let request = CreateRequest::new("mario")
            .with_message(ChatMessage::user("discarded"))
            .with_messages([ChatMessage::user("Who are you?")])
            .with_message(ChatMessage::assistant("It's-a me, Mario!"));
        assert_eq!(
            request.messages,
            Some(vec![
                ChatMessage::user("Who are you?"),
                ChatMessage::assistant("It's-a me, Mario!"),
            ])
        );
    }

    #[test]
    fn test_create_request_with_parameters() {
        let request = CreateRequest::from_model("mario", "qwen3:0.6b").with_parameters(json!({
            "temperature": 0.8,
            "top_k": 40
        }));
        let params = request
            .parameters
            .expect("with_parameters should set parameters");
        assert_eq!(params["temperature"], 0.8);
        assert_eq!(params["top_k"], 40);
    }

    #[test]
    fn test_create_request_serialization() {
        let request = CreateRequest::from_model("mario", "qwen3:0.6b").with_system("You are Mario");
        let value = serde_json::to_value(&request).unwrap();
        // Exact equality also proves no unexpected keys (e.g. `null` fields) are emitted.
        assert_eq!(
            value,
            json!({
                "model": "mario",
                "from": "qwen3:0.6b",
                "system": "You are Mario",
                "stream": false
            })
        );
    }

    #[test]
    fn test_create_request_serialization_skips_none() {
        let value = serde_json::to_value(CreateRequest::new("mario")).unwrap();
        assert_eq!(value, json!({ "model": "mario", "stream": false }));
    }

    #[test]
    fn test_create_request_serialization_omits_stream_when_none() {
        let mut request = CreateRequest::new("mario");
        request.stream = None;
        let value = serde_json::to_value(&request).unwrap();
        assert_eq!(value, json!({ "model": "mario" }));
    }

    #[test]
    fn test_create_request_round_trip() {
        // Integer/string parameters only: serde_json's default float parsing is
        // best-effort, so floats are not guaranteed to round-trip bit-exactly.
        let request = CreateRequest::from_model("mario", "qwen3:0.6b")
            .with_system("You are Mario")
            .with_template("{{ .System }}\n\n{{ .Prompt }}")
            .with_parameters(json!({ "top_k": 40, "stop": ["<|end|>"] }))
            .with_quantize("q4_K_M");
        let encoded = serde_json::to_string(&request).unwrap();
        let decoded: CreateRequest = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, request);
    }

    #[test]
    fn test_create_request_deserialize_minimal() {
        let decoded: CreateRequest = serde_json::from_str(r#"{"model":"mario"}"#).unwrap();
        assert_eq!(decoded.model, "mario");
        assert!(decoded.from.is_none());
        assert!(decoded.messages.is_none());
        assert!(decoded.stream.is_none());
    }
}