use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
pub use infernum_core::{GenerateRequest, SamplingParams};
pub use infernum_core::response::{Choice, GenerateResponse};
pub use infernum_core::request::{EmbedInput, EmbedRequest, EncodingFormat};
pub use infernum_core::response::{EmbedResponse, Embedding, EmbeddingData};
pub use infernum_core::streaming::{StreamChoice, StreamChunk, StreamDelta, TokenStream};
pub use infernum_core::types::{Message, Role};
pub use infernum_core::types::{ToolCall, ToolControl, ToolControlMode, ToolDefinition};
pub use infernum_core::types::{FinishReason, ModelId, RequestId, Usage};
pub use infernum_core::request::PromptInput;
/// Response payload for the model-listing endpoint: a flat list of the
/// models this server can serve.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ModelsResponse {
    pub models: Vec<ModelListEntry>,
}
/// Summary of a single model as reported by the model-listing endpoint.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ModelListEntry {
    /// Model identifier, e.g. "llama-3.2-3b".
    pub id: String,
    /// Architecture family label (e.g. "llama"); omitted from JSON when unknown.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub architecture: Option<String>,
    /// Maximum context length in tokens; omitted from JSON when unknown.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context_length: Option<u32>,
    /// Quantization scheme label (e.g. "gguf_q4_k_m"); omitted from JSON when unknown.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub quantization: Option<String>,
    /// Owner label; falls back to "infernum" when absent during deserialization.
    #[serde(default = "default_owned_by")]
    pub owned_by: String,
}
/// Serde default for [`ModelListEntry::owned_by`].
fn default_owned_by() -> String {
    "infernum".to_string()
}

impl ModelListEntry {
    /// Creates an entry with the given id, all optional metadata unset, and
    /// the default owner.
    pub fn new(id: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            architecture: None,
            context_length: None,
            quantization: None,
            // Single source of truth: reuse the serde default instead of
            // duplicating the "infernum" literal here.
            owned_by: default_owned_by(),
        }
    }

    /// Sets the architecture family label (builder style).
    pub fn with_architecture(mut self, arch: impl Into<String>) -> Self {
        self.architecture = Some(arch.into());
        self
    }

    /// Sets the maximum context length in tokens (builder style).
    pub fn with_context_length(mut self, length: u32) -> Self {
        self.context_length = Some(length);
        self
    }

    /// Sets the quantization scheme label (builder style).
    pub fn with_quantization(mut self, quant: impl Into<String>) -> Self {
        self.quantization = Some(quant.into());
        self
    }

    /// Overrides the owner label (builder style).
    pub fn with_owned_by(mut self, owner: impl Into<String>) -> Self {
        self.owned_by = owner.into();
        self
    }
}
/// Builds a list entry from core model metadata.
impl From<&infernum_core::ModelMetadata> for ModelListEntry {
    fn from(meta: &infernum_core::ModelMetadata) -> Self {
        // Derive a plain-string architecture name by serializing the core
        // architecture value and reading its "type" field.
        // NOTE(review): assumes the architecture serializes to a JSON object
        // with a string "type" key — confirm against infernum_core's serde
        // representation.
        let architecture = serde_json::to_value(&meta.architecture)
            .ok()
            .and_then(|v| v.get("type").and_then(|t| t.as_str().map(String::from)));
        // Prefer the serde string form of the quantization; fall back to its
        // Debug representation when it does not serialize to a plain string.
        let quantization = meta.quantization.map(|q| {
            serde_json::to_value(q)
                .ok()
                .and_then(|v| v.as_str().map(String::from))
                .unwrap_or_else(|| format!("{q:?}"))
        });
        Self {
            id: meta.id.to_string(),
            architecture,
            // Context length is always known for loaded metadata.
            context_length: Some(meta.context_length),
            quantization,
            // Same literal as `default_owned_by()`.
            owned_by: "infernum".to_string(),
        }
    }
}
/// Top-level error envelope returned by the API: `{ "error": { ... } }`.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ErrorResponse {
    pub error: ErrorBody,
}
/// Error payload carried inside [`ErrorResponse`].
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ErrorBody {
    /// Machine-readable error category.
    pub code: ErrorCode,
    /// Human-readable description of the failure.
    pub message: String,
    /// Optional structured context (e.g. offending values and constraints);
    /// omitted from JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<serde_json::Value>,
}
/// Machine-readable error categories; serialized in snake_case
/// (e.g. `invalid_request`). HTTP mapping lives in [`ErrorCode::status_code`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ToSchema)]
#[serde(rename_all = "snake_case")]
pub enum ErrorCode {
    /// Malformed or out-of-range request parameters (HTTP 400).
    InvalidRequest,
    /// Requested model is not available (HTTP 404).
    ModelNotFound,
    /// Input exceeds the model's context length (HTTP 400).
    ContextOverflow,
    /// Caller exceeded its request quota (HTTP 429).
    RateLimited,
    /// Model is busy serving other requests (HTTP 503).
    ModelBusy,
    /// Unexpected server-side failure (HTTP 500).
    InternalError,
}
impl ErrorCode {
    /// Maps the error category to the HTTP status code the API responds with.
    pub fn status_code(self) -> u16 {
        // One arm per variant, ordered by status code for easy scanning.
        match self {
            Self::InvalidRequest => 400,
            Self::ContextOverflow => 400,
            Self::ModelNotFound => 404,
            Self::RateLimited => 429,
            Self::InternalError => 500,
            Self::ModelBusy => 503,
        }
    }
}
impl ErrorResponse {
    /// Builds an error response with the given code and message and no details.
    pub fn new(code: ErrorCode, message: impl Into<String>) -> Self {
        let body = ErrorBody {
            code,
            message: message.into(),
            details: None,
        };
        Self { error: body }
    }

    /// Builds an error response that also carries structured `details`.
    pub fn with_details(
        code: ErrorCode,
        message: impl Into<String>,
        details: serde_json::Value,
    ) -> Self {
        // Delegate the common fields to `new`, then attach the details.
        let mut response = Self::new(code, message);
        response.error.details = Some(details);
        response
    }

    /// Convenience constructor for `invalid_request` errors.
    pub fn invalid_request(message: impl Into<String>) -> Self {
        Self::new(ErrorCode::InvalidRequest, message)
    }

    /// Convenience constructor for `model_not_found` errors.
    pub fn model_not_found(message: impl Into<String>) -> Self {
        Self::new(ErrorCode::ModelNotFound, message)
    }

    /// Builds a `context_overflow` error; the token count and limit are
    /// echoed both in the message and as structured details.
    pub fn context_overflow(tokens: u64, limit: u64) -> Self {
        let details = serde_json::json!({
            "tokens": tokens,
            "limit": limit,
        });
        Self::with_details(
            ErrorCode::ContextOverflow,
            format!("Input ({tokens} tokens) exceeds context length ({limit})"),
            details,
        )
    }

    /// Convenience constructor for `rate_limited` errors.
    pub fn rate_limited(message: impl Into<String>) -> Self {
        Self::new(ErrorCode::RateLimited, message)
    }

    /// Builds the fixed `model_busy` error.
    pub fn model_busy() -> Self {
        Self::new(
            ErrorCode::ModelBusy,
            "Model is currently processing other requests",
        )
    }

    /// Convenience constructor for `internal_error` errors.
    pub fn internal_error(message: impl Into<String>) -> Self {
        Self::new(ErrorCode::InternalError, message)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // A fully-populated entry serializes with every field present.
    #[test]
    fn test_models_response_serialization() {
        let response = ModelsResponse {
            models: vec![ModelListEntry::new("llama-3.2-3b")
                .with_architecture("llama")
                .with_context_length(8192)
                .with_quantization("gguf_q4_k_m")],
        };
        let json = serde_json::to_value(&response).expect("serialize");
        assert_eq!(json["models"][0]["id"], "llama-3.2-3b");
        assert_eq!(json["models"][0]["architecture"], "llama");
        assert_eq!(json["models"][0]["context_length"], 8192);
        assert_eq!(json["models"][0]["quantization"], "gguf_q4_k_m");
        assert_eq!(json["models"][0]["owned_by"], "infernum");
    }

    // The spec's example payload deserializes as-is.
    #[test]
    fn test_models_response_spec_example() {
        let json = r#"{
        "models": [
        {
        "id": "llama-3.2-3b",
        "architecture": "llama",
        "context_length": 8192,
        "quantization": "gguf_q4_k_m",
        "owned_by": "infernum"
        }
        ]
        }"#;
        let parsed: ModelsResponse = serde_json::from_str(json).expect("deserialize");
        assert_eq!(parsed.models.len(), 1);
        assert_eq!(parsed.models[0].id, "llama-3.2-3b");
        assert_eq!(parsed.models[0].architecture.as_deref(), Some("llama"));
    }

    // `None` optional fields are skipped entirely in the JSON output.
    #[test]
    fn test_model_list_entry_minimal() {
        let entry = ModelListEntry::new("test-model");
        let json = serde_json::to_value(&entry).expect("serialize");
        assert_eq!(json["id"], "test-model");
        assert_eq!(json["owned_by"], "infernum");
        assert!(json.get("architecture").is_none());
        assert!(json.get("context_length").is_none());
        assert!(json.get("quantization").is_none());
    }

    // Builder-constructed entry survives a serialize/deserialize round trip.
    #[test]
    fn test_model_list_entry_roundtrip() {
        let entry = ModelListEntry::new("qwen-2.5-7b")
            .with_architecture("qwen2")
            .with_context_length(32768)
            .with_quantization("f16")
            .with_owned_by("custom");
        let json = serde_json::to_string(&entry).expect("serialize");
        let parsed: ModelListEntry = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.id, "qwen-2.5-7b");
        assert_eq!(parsed.architecture.as_deref(), Some("qwen2"));
        assert_eq!(parsed.context_length, Some(32768));
        assert_eq!(parsed.quantization.as_deref(), Some("f16"));
        assert_eq!(parsed.owned_by, "custom");
    }

    // Error code serializes in snake_case and empty details are skipped.
    #[test]
    fn test_error_response_serialization() {
        let err = ErrorResponse::invalid_request("Temperature must be between 0.0 and 2.0");
        let json = serde_json::to_value(&err).expect("serialize");
        assert_eq!(json["error"]["code"], "invalid_request");
        assert_eq!(
            json["error"]["message"],
            "Temperature must be between 0.0 and 2.0"
        );
        assert!(json["error"].get("details").is_none());
    }

    // The spec's example error payload (with details) deserializes as-is.
    #[test]
    fn test_error_response_spec_example() {
        let json = r#"{
        "error": {
        "code": "invalid_request",
        "message": "Temperature must be between 0.0 and 2.0",
        "details": {
        "field": "sampling.temperature",
        "value": 3.0,
        "constraint": "0.0 <= temperature <= 2.0"
        }
        }
        }"#;
        let parsed: ErrorResponse = serde_json::from_str(json).expect("deserialize");
        assert_eq!(parsed.error.code, ErrorCode::InvalidRequest);
        assert!(parsed.error.details.is_some());
        let details = parsed.error.details.as_ref().expect("details");
        assert_eq!(details["field"], "sampling.temperature");
    }

    // context_overflow embeds the token counts as structured details.
    #[test]
    fn test_error_context_overflow() {
        let err = ErrorResponse::context_overflow(20000, 16384);
        let json = serde_json::to_value(&err).expect("serialize");
        assert_eq!(json["error"]["code"], "context_overflow");
        let details = &json["error"]["details"];
        assert_eq!(details["tokens"], 20000);
        assert_eq!(details["limit"], 16384);
    }

    // Every variant maps to its documented HTTP status code.
    #[test]
    fn test_error_code_status_codes() {
        assert_eq!(ErrorCode::InvalidRequest.status_code(), 400);
        assert_eq!(ErrorCode::ModelNotFound.status_code(), 404);
        assert_eq!(ErrorCode::ContextOverflow.status_code(), 400);
        assert_eq!(ErrorCode::RateLimited.status_code(), 429);
        assert_eq!(ErrorCode::ModelBusy.status_code(), 503);
        assert_eq!(ErrorCode::InternalError.status_code(), 500);
    }

    // Every variant survives a serde round trip.
    #[test]
    fn test_error_code_roundtrip() {
        for code in [
            ErrorCode::InvalidRequest,
            ErrorCode::ModelNotFound,
            ErrorCode::ContextOverflow,
            ErrorCode::RateLimited,
            ErrorCode::ModelBusy,
            ErrorCode::InternalError,
        ] {
            let json = serde_json::to_value(code).expect("serialize");
            let parsed: ErrorCode = serde_json::from_value(json).expect("deserialize");
            assert_eq!(parsed, code);
        }
    }

    // Re-exported infernum_core message/tool types are usable from this crate.
    #[test]
    fn test_core_types_accessible() {
        let msg = Message::user("Hello");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content, "Hello");
        let tool = ToolDefinition {
            name: "read_file".to_string(),
            description: Some("Read a file".to_string()),
            parameters: None,
            strict: None,
        };
        assert_eq!(tool.name, "read_file");
    }

    // Re-exported GenerateRequest builder and PromptInput are usable.
    #[test]
    fn test_generate_request_accessible() {
        let req = GenerateRequest::new("Hello, world!")
            .with_sampling(SamplingParams::default().with_max_tokens(100));
        match &req.prompt {
            PromptInput::Text(t) => assert_eq!(t, "Hello, world!"),
            _ => panic!("Expected text prompt"),
        }
    }

    // Re-exported embedding request types are usable.
    #[test]
    fn test_embed_request_accessible() {
        let req = EmbedRequest {
            request_id: RequestId::new(),
            model: Some(ModelId::from("nomic-embed")),
            input: EmbedInput::Single("test".to_string()),
            encoding_format: EncodingFormat::Float,
            dimensions: None,
        };
        assert!(req.model.is_some());
    }
}