use serde::{Deserialize, Serialize};
use std::time::Duration;
#[cfg(feature = "llm")]
use std::time::Instant;
/// Incremental message fragment carried by one choice of a streamed chunk.
#[derive(Clone, Debug, Deserialize)]
pub struct StreamDelta {
    /// Text contributed by this chunk; `None` when the delta carries no content.
    pub content: Option<String>,
}
/// One entry of the `choices` array inside a streamed chunk.
#[derive(Clone, Debug, Deserialize)]
pub struct StreamChoice {
    /// The incremental content for this choice.
    pub delta: StreamDelta,
    /// Reason generation stopped, once the server reports one; `None` otherwise.
    pub finish_reason: Option<String>,
}
/// JSON payload of a single streamed `data:` event.
#[derive(Clone, Debug, Deserialize)]
pub struct StreamChunk {
    /// Choices updated by this event.
    pub choices: Vec<StreamChoice>,
    /// Token accounting, when the event carries it.
    pub usage: Option<Usage>,
}
/// Aggregated result of a streamed chat completion, with timing measurements.
///
/// All durations are measured from the instant just before the request was sent.
#[derive(Clone, Debug)]
pub struct StreamedChatResponse {
    /// Concatenation of every non-empty content delta, in arrival order.
    pub content: String,
    /// Elapsed time when the stream finished.
    pub latency: Duration,
    /// Elapsed time at the first non-empty content delta (equals `latency` if none arrived).
    pub ttft: Duration,
    /// Elapsed time recorded for each non-empty content delta.
    pub token_timestamps: Vec<Duration>,
    /// The most recent usage block reported by the stream, if any.
    pub usage: Option<Usage>,
    /// The most recent finish reason reported for the first choice, if any.
    pub finish_reason: Option<String>,
}
/// Author of a chat message; serialized as the lowercase variant name.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Serialized as `"system"`.
    System,
    /// Serialized as `"user"`.
    User,
    /// Serialized as `"assistant"`.
    Assistant,
}
/// A single message of a chat conversation.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Who authored the message.
    pub role: Role,
    /// The message text.
    pub content: String,
}
/// Request body posted to the chat-completions endpoint.
///
/// Optional fields that are `None` are left out of the serialized JSON entirely.
#[derive(Clone, Debug, Serialize)]
pub struct ChatRequest {
    /// Model identifier. `LlmClient::send` and friends substitute the client's
    /// model when this is empty.
    pub model: String,
    /// Conversation so far.
    pub messages: Vec<ChatMessage>,
    /// Sampling temperature; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// Generation token limit; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Whether a streamed response is requested; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
}
/// Token accounting reported by the server; every counter defaults to zero.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Usage {
    /// Tokens counted for the prompt.
    pub prompt_tokens: u32,
    /// Tokens counted for the completion.
    pub completion_tokens: u32,
    /// Total reported by the server (presumably prompt + completion; not checked here).
    pub total_tokens: u32,
}
/// One entry of the `choices` array in a non-streamed response.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ChatResponseChoice {
    /// Index of this choice as reported by the server.
    pub index: u32,
    /// The generated message.
    pub message: ChatMessage,
    /// Reason generation stopped, if reported.
    pub finish_reason: Option<String>,
}
/// Body of a non-streamed chat-completions response.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ChatResponse {
    /// Server-assigned response identifier.
    pub id: String,
    /// Object type tag (the test fixtures use `"chat.completion"`).
    pub object: String,
    /// Creation timestamp as sent by the server (presumably Unix seconds; not interpreted here).
    pub created: u64,
    /// Model name echoed by the server.
    pub model: String,
    /// Generated choices; may be empty.
    pub choices: Vec<ChatResponseChoice>,
    /// Token accounting, when the server provides it.
    pub usage: Option<Usage>,
    /// Optional trace data; `None` when the key is absent.
    #[serde(default)]
    pub brick_trace: Option<BrickTrace>,
}
/// Trace data a server may attach to a response under `brick_trace`.
///
/// NOTE(review): field meanings below are inferred from their names; confirm
/// against the server that produces them.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct BrickTrace {
    /// Trace level label reported by the server.
    pub level: String,
    /// Operation count reported by the server.
    pub operations: usize,
    /// Total time, presumably in microseconds given the `_us` suffix.
    pub total_time_us: u64,
    /// Per-operation entries.
    pub breakdown: Vec<BrickTraceOp>,
}
/// One entry of a [`BrickTrace`] breakdown.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct BrickTraceOp {
    /// Operation name.
    pub name: String,
    /// Time attributed to the operation, presumably in microseconds given the `_us` suffix.
    pub time_us: u64,
    /// Free-form extra information; `None` when the key is absent.
    #[serde(default)]
    pub details: Option<String>,
}
/// A parsed [`ChatResponse`] together with client-side timing.
///
/// Both durations are measured from the instant just before the request was sent.
#[derive(Clone, Debug)]
pub struct TimedChatResponse {
    /// The decoded response body.
    pub response: ChatResponse,
    /// Elapsed time once the body had been fully read and parsed.
    pub latency: Duration,
    /// Elapsed time when the HTTP `send` call returned a response.
    pub ttfb: Duration,
    /// Copy of `response.brick_trace`, surfaced for convenience.
    pub brick_trace: Option<BrickTrace>,
}
/// Failures reported by [`LlmClient`].
#[cfg(feature = "llm")]
#[derive(thiserror::Error, Debug)]
pub enum LlmClientError {
    /// The HTTP layer failed (converted automatically from `reqwest::Error`).
    #[error("HTTP error: {0}")]
    Http(#[from] reqwest::Error),
    /// The server answered with a non-success status; `body` holds the response
    /// text, or an empty string when it could not be read.
    #[error("API error {status}: {body}")]
    ApiError { status: u16, body: String },
    /// None of the probed health endpoints answered with a success status.
    #[error("Health check failed: {0}")]
    HealthCheckFailed(String),
    /// The server did not become healthy within the given duration.
    #[error("Health check timed out after {0:?}")]
    HealthCheckTimeout(Duration),
}
/// HTTP client for a server exposing `/v1/chat/completions`.
#[cfg(feature = "llm")]
#[derive(Clone, Debug)]
pub struct LlmClient {
    /// Server root, stored without trailing `/` characters.
    base_url: String,
    /// Underlying HTTP client used for every request.
    client: reqwest::Client,
    /// Model name used when a request does not name one.
    model: String,
}
#[cfg(feature = "llm")]
impl LlmClient {
/// Creates a client for `base_url` that uses `model` by default.
///
/// The HTTP client is built with a 120-second request timeout; if building it
/// fails, a default `reqwest::Client` is used instead. Trailing `/` characters
/// are removed from `base_url`.
pub fn new(base_url: impl Into<String>, model: impl Into<String>) -> Self {
    let request_timeout = Duration::from_secs(120);
    let http = match reqwest::Client::builder().timeout(request_timeout).build() {
        Ok(built) => built,
        // Best effort: fall back to the library default rather than failing.
        Err(_) => reqwest::Client::default(),
    };
    let raw_url: String = base_url.into();
    Self {
        base_url: raw_url.trim_end_matches('/').to_owned(),
        client: http,
        model: model.into(),
    }
}
/// Creates a client that sends requests through the caller-supplied `client`.
///
/// Trailing `/` characters are removed from `base_url`.
pub fn with_client(
    base_url: impl Into<String>,
    model: impl Into<String>,
    client: reqwest::Client,
) -> Self {
    let raw_url: String = base_url.into();
    let base_url = raw_url.trim_end_matches('/').to_owned();
    let model = model.into();
    Self {
        base_url,
        client,
        model,
    }
}
/// Returns the server root URL, without trailing `/` characters.
pub fn base_url(&self) -> &str {
    self.base_url.as_str()
}
/// Returns the model name this client uses by default.
pub fn model(&self) -> &str {
    self.model.as_str()
}
/// Builds a non-streaming request from the given parts and posts it to
/// `{base_url}/v1/chat/completions`.
///
/// The request always names the client's configured model and sets `stream`
/// to `false`.
///
/// # Errors
///
/// * [`LlmClientError::Http`] if sending the request or decoding the body fails.
/// * [`LlmClientError::ApiError`] if the server answers with a non-success status.
pub async fn chat_completion(
    &self,
    messages: Vec<ChatMessage>,
    temperature: Option<f64>,
    max_tokens: Option<u32>,
) -> Result<TimedChatResponse, LlmClientError> {
    let request = ChatRequest {
        model: self.model.clone(),
        messages,
        temperature,
        max_tokens,
        stream: Some(false),
    };
    self.execute_chat_request(&request, None).await
}

/// Posts a caller-built `request` to `{base_url}/v1/chat/completions`.
///
/// An empty `request.model` is replaced by the client's configured model; all
/// other fields are sent as given.
///
/// # Errors
///
/// * [`LlmClientError::Http`] if sending the request or decoding the body fails.
/// * [`LlmClientError::ApiError`] if the server answers with a non-success status.
pub async fn send(&self, request: &ChatRequest) -> Result<TimedChatResponse, LlmClientError> {
    self.execute_chat_request(request, None).await
}

/// Same as [`send`](Self::send), but also sets the `X-Trace-Level` request
/// header to `trace_level`.
///
/// # Errors
///
/// * [`LlmClientError::Http`] if sending the request or decoding the body fails.
/// * [`LlmClientError::ApiError`] if the server answers with a non-success status.
pub async fn send_with_trace(
    &self,
    request: &ChatRequest,
    trace_level: &str,
) -> Result<TimedChatResponse, LlmClientError> {
    self.execute_chat_request(request, Some(trace_level)).await
}

/// Shared request pipeline: posts `request`, times it, and decodes the reply.
///
/// Kept in one place so the public entry points cannot drift apart in how
/// they fill in the model, measure time, or report errors.
async fn execute_chat_request(
    &self,
    request: &ChatRequest,
    trace_level: Option<&str>,
) -> Result<TimedChatResponse, LlmClientError> {
    let url = format!("{}/v1/chat/completions", self.base_url);
    let start = Instant::now();

    // Only clone the request when the model actually has to be filled in.
    let with_default_model;
    let payload = if request.model.is_empty() {
        with_default_model = ChatRequest {
            model: self.model.clone(),
            ..request.clone()
        };
        &with_default_model
    } else {
        request
    };

    let mut builder = self.client.post(&url);
    if let Some(level) = trace_level {
        builder = builder.header("X-Trace-Level", level);
    }
    let resp = builder.json(payload).send().await?;
    let ttfb = start.elapsed();

    let status = resp.status();
    if !status.is_success() {
        // Best effort: an unreadable error body is reported as an empty string.
        let body = resp.text().await.unwrap_or_default();
        return Err(LlmClientError::ApiError {
            status: status.as_u16(),
            body,
        });
    }

    let response: ChatResponse = resp.json().await?;
    let latency = start.elapsed();
    let brick_trace = response.brick_trace.clone();
    Ok(TimedChatResponse {
        response,
        latency,
        ttfb,
        brick_trace,
    })
}
/// Probes `/health`, `/v1/models`, and `/` in that order and returns
/// `Ok(true)` as soon as one answers with a success status.
///
/// Transport errors and non-success statuses on individual probes are ignored;
/// the next path is tried instead.
///
/// # Errors
///
/// [`LlmClientError::HealthCheckFailed`] when no probe succeeds.
pub async fn health_check(&self) -> Result<bool, LlmClientError> {
    const PROBE_PATHS: [&str; 3] = ["/health", "/v1/models", "/"];
    for probe in PROBE_PATHS.iter() {
        let target = format!("{}{probe}", self.base_url);
        match self.client.get(&target).send().await {
            Ok(reply) if reply.status().is_success() => return Ok(true),
            // Failed probe: move on to the next candidate path.
            _ => {}
        }
    }
    let reason = format!("No health endpoint responded at {}", self.base_url);
    Err(LlmClientError::HealthCheckFailed(reason))
}
/// Sends `request` with `stream` forced to `true` and folds the resulting
/// `data:` event lines into a single [`StreamedChatResponse`].
///
/// An empty `request.model` is replaced by the client's configured model.
/// Only the first choice of each chunk is read. Event payloads that do not
/// parse as a [`StreamChunk`] are skipped. Reading stops at a `[DONE]` event
/// or when the response body ends, whichever comes first.
///
/// # Errors
///
/// * [`LlmClientError::Http`] if sending the request or reading the body fails.
/// * [`LlmClientError::ApiError`] if the server answers with a non-success status.
pub async fn chat_completion_stream(
    &self,
    request: &ChatRequest,
) -> Result<StreamedChatResponse, LlmClientError> {
    /// Returns the payload of an SSE `data` line, or `None` for any other line.
    /// A single space after the colon is optional and not part of the payload.
    fn data_payload(line: &str) -> Option<&str> {
        let rest = line.strip_prefix("data:")?;
        Some(rest.strip_prefix(' ').unwrap_or(rest))
    }

    let url = format!("{}/v1/chat/completions", self.base_url);
    let stream_request = ChatRequest {
        model: if request.model.is_empty() {
            self.model.clone()
        } else {
            request.model.clone()
        },
        messages: request.messages.clone(),
        temperature: request.temperature,
        max_tokens: request.max_tokens,
        stream: Some(true),
    };

    let start = Instant::now();
    let mut resp = self.client.post(&url).json(&stream_request).send().await?;
    let status = resp.status();
    if !status.is_success() {
        // Best effort: an unreadable error body is reported as an empty string.
        let body = resp.text().await.unwrap_or_default();
        return Err(LlmClientError::ApiError {
            status: status.as_u16(),
            body,
        });
    }

    let mut content = String::new();
    let mut token_timestamps = Vec::new();
    let mut ttft = None;
    let mut final_usage = None;
    let mut finish_reason = None;

    // Buffer raw bytes, not text: a network chunk may end in the middle of a
    // multi-byte UTF-8 character, so decoding is deferred until a whole line
    // is available. The byte `\n` never occurs inside a multi-byte sequence,
    // which makes splitting on it safe.
    let mut buffer: Vec<u8> = Vec::new();
    let mut done = false;
    while !done {
        match resp.chunk().await? {
            Some(bytes) => buffer.extend_from_slice(&bytes),
            None => {
                done = true;
                // Terminate a trailing line the server did not end with a
                // newline so it is processed instead of being dropped.
                if !buffer.is_empty() {
                    buffer.push(b'\n');
                }
            }
        }

        // Walk all complete lines with a cursor and compact the buffer once
        // afterwards, instead of re-allocating it for every line.
        let mut consumed = 0;
        while let Some(offset) = buffer[consumed..].iter().position(|&b| b == b'\n') {
            let line_end = consumed + offset;
            let line = String::from_utf8_lossy(&buffer[consumed..line_end]);
            consumed = line_end + 1;

            let payload = match data_payload(line.trim()) {
                Some(payload) => payload,
                None => continue,
            };
            if payload == "[DONE]" {
                done = true;
                break;
            }
            let sse_chunk = match serde_json::from_str::<StreamChunk>(payload) {
                Ok(parsed) => parsed,
                // Unparseable payloads are skipped rather than aborting the stream.
                Err(_) => continue,
            };

            if let Some(choice) = sse_chunk.choices.first() {
                if let Some(ref text) = choice.delta.content {
                    if !text.is_empty() {
                        let now = start.elapsed();
                        if ttft.is_none() {
                            ttft = Some(now);
                        }
                        token_timestamps.push(now);
                        content.push_str(text);
                    }
                }
                if choice.finish_reason.is_some() {
                    finish_reason = choice.finish_reason.clone();
                }
            }
            if sse_chunk.usage.is_some() {
                final_usage = sse_chunk.usage;
            }
        }
        buffer.drain(..consumed);
    }

    let latency = start.elapsed();
    Ok(StreamedChatResponse {
        content,
        latency,
        // No content ever arrived: report the full latency as time-to-first-token.
        ttft: ttft.unwrap_or(latency),
        token_timestamps,
        usage: final_usage,
        finish_reason,
    })
}
/// Polls [`health_check`](Self::health_check) until it succeeds or `timeout`
/// has elapsed, and returns how long the server took to become healthy.
///
/// Each probe is itself limited to the time remaining, and the pause between
/// probes is shortened to the time remaining, so the call returns close to
/// `timeout` even when the server accepts connections but never answers.
///
/// # Errors
///
/// [`LlmClientError::HealthCheckTimeout`] carrying `timeout` when the deadline
/// passes without a successful health check.
pub async fn wait_ready(
    &self,
    timeout: Duration,
    poll_interval: Duration,
) -> Result<Duration, LlmClientError> {
    let start = Instant::now();
    loop {
        // `checked_sub` yields `None` exactly when the deadline has passed.
        let remaining = match timeout.checked_sub(start.elapsed()) {
            Some(remaining) => remaining,
            None => return Err(LlmClientError::HealthCheckTimeout(timeout)),
        };

        // Bound the probe by the deadline; a probe cut short counts as a failure.
        if let Ok(Ok(_)) = tokio::time::timeout(remaining, self.health_check()).await {
            return Ok(start.elapsed());
        }

        // Never sleep past the deadline.
        let remaining = timeout.saturating_sub(start.elapsed());
        tokio::time::sleep(poll_interval.min(remaining)).await;
    }
}
}
// Unit tests for the request/response types and for the parts of `LlmClient`
// that need no network. Client tests are compiled only with the `llm` feature.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
use super::*;
// `new` stores the base URL and model it is given.
#[cfg(feature = "llm")]
#[test]
fn test_client_creation() {
let client = LlmClient::new("http://localhost:8081", "qwen-coder");
assert_eq!(client.base_url(), "http://localhost:8081");
assert_eq!(client.model(), "qwen-coder");
}
// A trailing `/` on the base URL is removed.
#[cfg(feature = "llm")]
#[test]
fn test_client_strips_trailing_slash() {
let client = LlmClient::new("http://localhost:8081/", "model");
assert_eq!(client.base_url(), "http://localhost:8081");
}
// The role is serialized in lowercase alongside the content.
#[test]
fn test_chat_message_serialization() {
let msg = ChatMessage {
role: Role::User,
content: "Hello".to_string(),
};
let json = serde_json::to_string(&msg).unwrap();
assert!(json.contains("\"role\":\"user\""));
assert!(json.contains("\"content\":\"Hello\""));
}
// Set options are serialized; `stream: None` is left out.
#[test]
fn test_chat_request_serialization() {
let req = ChatRequest {
model: "test".to_string(),
messages: vec![ChatMessage {
role: Role::User,
content: "Hi".to_string(),
}],
temperature: Some(0.0),
max_tokens: Some(32),
stream: None,
};
let json = serde_json::to_string(&req).unwrap();
assert!(json.contains("\"temperature\":0.0"));
assert!(json.contains("\"max_tokens\":32"));
assert!(!json.contains("stream"));
}
// Every optional field that is `None` is absent from the JSON.
#[test]
fn test_chat_request_omits_none_fields() {
let req = ChatRequest {
model: "test".to_string(),
messages: vec![],
temperature: None,
max_tokens: None,
stream: None,
};
let json = serde_json::to_string(&req).unwrap();
assert!(!json.contains("temperature"));
assert!(!json.contains("max_tokens"));
assert!(!json.contains("stream"));
}
// A complete response with one choice and a usage block deserializes.
#[test]
fn test_chat_response_deserialization() {
let json = r#"{
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1700000000,
"model": "qwen-coder",
"choices": [{
"index": 0,
"message": {"role": "assistant", "content": "Hello!"},
"finish_reason": "stop"
}],
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
}"#;
let resp: ChatResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.id, "chatcmpl-123");
assert_eq!(resp.choices.len(), 1);
assert_eq!(resp.choices[0].message.content, "Hello!");
let usage = resp.usage.unwrap();
assert_eq!(usage.total_tokens, 15);
}
// A response carrying an extra top-level `_apr_metrics` key still deserializes.
#[test]
fn test_apr_response_deserialization() {
let json = r#"{"_apr_metrics":{"latency_ms":1978,"tok_per_sec":4.14},"choices":[{"finish_reason":"stop","index":0,"message":{"content":"hello","role":"assistant"}}],"created":1772386202,"id":"chatcmpl-123","model":"test","object":"chat.completion","usage":{"completion_tokens":8,"prompt_tokens":9,"total_tokens":17}}"#;
let resp: ChatResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.choices[0].message.content, "hello");
}
// A message carrying an extra `"name": null` key still deserializes.
#[test]
fn test_gguf_response_with_name_null() {
let json = r#"{"id":"chatcmpl-q4k-123","object":"chat.completion","created":1772385841,"model":"qwen","choices":[{"index":0,"message":{"role":"assistant","content":"4","name":null},"finish_reason":"stop"}],"usage":{"prompt_tokens":24,"completion_tokens":1,"total_tokens":25}}"#;
let resp: ChatResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.choices[0].message.content, "4");
}
// `usage` may be missing and `choices` may be empty.
#[test]
fn test_chat_response_without_usage() {
let json = r#"{
"id": "abc",
"object": "chat.completion",
"created": 0,
"model": "m",
"choices": []
}"#;
let resp: ChatResponse = serde_json::from_str(json).unwrap();
assert!(resp.usage.is_none());
assert!(resp.choices.is_empty());
}
// Each role serializes to its lowercase name and parses back to itself.
#[test]
fn test_role_serialization_roundtrip() {
for (role, expected) in [
(Role::System, "\"system\""),
(Role::User, "\"user\""),
(Role::Assistant, "\"assistant\""),
] {
let json = serde_json::to_string(&role).unwrap();
assert_eq!(json, expected);
let back: Role = serde_json::from_str(&json).unwrap();
assert_eq!(back, role);
}
}
// The default `Usage` has all counters at zero.
#[test]
fn test_usage_default() {
let usage = Usage::default();
assert_eq!(usage.prompt_tokens, 0);
assert_eq!(usage.completion_tokens, 0);
assert_eq!(usage.total_tokens, 0);
}
// `with_client` accepts a caller-built HTTP client and keeps the base URL.
#[cfg(feature = "llm")]
#[test]
fn test_client_with_custom_client() {
let http = reqwest::Client::builder()
.timeout(Duration::from_secs(10))
.build()
.unwrap();
let client = LlmClient::with_client("http://example.com", "model", http);
assert_eq!(client.base_url(), "http://example.com");
}
// The timeout error message mentions the duration and the words "timed out".
#[cfg(feature = "llm")]
#[test]
fn test_health_check_timeout_error_display() {
let err = LlmClientError::HealthCheckTimeout(Duration::from_secs(30));
let msg = err.to_string();
assert!(msg.contains("30"));
assert!(msg.contains("timed out"));
}
// A content chunk with a null finish reason and no usage block deserializes.
#[test]
fn test_stream_chunk_deserialization() {
let json = r#"{"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}"#;
let chunk: StreamChunk = serde_json::from_str(json).unwrap();
assert_eq!(chunk.choices.len(), 1);
assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("Hello"));
assert!(chunk.choices[0].finish_reason.is_none());
assert!(chunk.usage.is_none());
}
// A final chunk with an empty delta, a finish reason, and usage deserializes.
#[test]
fn test_stream_chunk_final_with_usage() {
let json = r#"{"choices":[{"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15}}"#;
let chunk: StreamChunk = serde_json::from_str(json).unwrap();
assert_eq!(chunk.choices[0].finish_reason.as_deref(), Some("stop"));
assert!(chunk.choices[0].delta.content.is_none());
let usage = chunk.usage.unwrap();
assert_eq!(usage.completion_tokens, 5);
}
// An empty-string content is kept as `Some("")`, not turned into `None`.
#[test]
fn test_stream_chunk_empty_content() {
let json = r#"{"choices":[{"delta":{"content":""},"finish_reason":null}]}"#;
let chunk: StreamChunk = serde_json::from_str(json).unwrap();
assert_eq!(chunk.choices[0].delta.content.as_deref(), Some(""));
}
// `stream: Some(true)` is serialized as `"stream":true`.
#[test]
fn test_chat_request_with_stream_true() {
let req = ChatRequest {
model: "test".to_string(),
messages: vec![],
temperature: None,
max_tokens: None,
stream: Some(true),
};
let json = serde_json::to_string(&req).unwrap();
assert!(json.contains("\"stream\":true"));
}
}