use serde::{Deserialize, Serialize};
use std::sync::{OnceLock, RwLock, RwLockReadGuard, RwLockWriteGuard};
/// Selects which backend an LLM client uses.
///
/// The serde impls are derived without rename attributes, so the variants
/// serialize by name (`"Gateway"` / `"Direct"` in JSON).
///
/// NOTE(review): the routing behaviour behind each variant is not visible in
/// this file — confirm against the client implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum LlmBackend {
/// The default variant. Presumably sends requests to the configured
/// `gateway_url` — TODO confirm.
#[default]
Gateway,
/// Presumably talks to a provider directly (see `direct_provider` on
/// `LlmClientConfig`) — TODO confirm.
Direct,
}
/// Settings for an LLM client.
///
/// The `Default` impl yields: `LlmBackend::Gateway`, the URL returned by
/// `default_gateway_url()`, no API key, model `"gpt-4o-mini"`, a timeout of
/// `60000`, no direct provider, and `debug` off.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmClientConfig {
/// Backend selection; see `LlmBackend`.
pub backend: LlmBackend,
/// Gateway base URL. The default value is resolved by `default_gateway_url()`.
pub gateway_url: String,
/// Optional API key; `None` by default.
/// NOTE(review): how the key is transmitted is not visible in this file.
pub api_key: Option<String>,
/// Optional model name; defaults to `"gpt-4o-mini"`.
/// Presumably used when a request does not name a model — TODO confirm.
pub default_model: Option<String>,
/// Timeout value; defaults to `60000`. Milliseconds, judging by the field name.
pub timeout_ms: u32,
/// Optional provider name; `None` by default.
/// Presumably only relevant with `LlmBackend::Direct` — TODO confirm.
pub direct_provider: Option<String>,
/// Debug flag; `false` by default. Its effect is not visible in this file.
pub debug: bool,
}
impl Default for LlmClientConfig {
    /// Produces the baseline configuration: gateway backend, the gateway URL
    /// resolved by `default_gateway_url()`, the `"gpt-4o-mini"` model, a
    /// 60000 ms timeout, and every optional/diagnostic setting left off.
    fn default() -> Self {
        // Resolve the computed values first so the literal below stays flat.
        let gateway_url = default_gateway_url();
        let default_model = Some(String::from("gpt-4o-mini"));

        Self {
            backend: LlmBackend::Gateway,
            gateway_url,
            api_key: None,
            default_model,
            timeout_ms: 60_000,
            direct_provider: None,
            debug: false,
        }
    }
}
/// Process-wide slot for an optional gateway URL override.
///
/// The inner lock is created lazily the first time it is requested.
static GATEWAY_URL_OVERRIDE: OnceLock<RwLock<Option<String>>> = OnceLock::new();

/// Returns the shared override lock, creating it (holding `None`) on first use.
fn gateway_url_override() -> &'static RwLock<Option<String>> {
    // `RwLock<Option<String>>::default()` wraps `Option::default()`, i.e. `None`.
    GATEWAY_URL_OVERRIDE.get_or_init(RwLock::default)
}
/// Acquires a read guard on the gateway URL override.
///
/// A poisoned lock is not treated as fatal: the guard carried by the poison
/// error is returned instead, so the stored value can still be read.
fn read_gateway_url_override() -> RwLockReadGuard<'static, Option<String>> {
    match gateway_url_override().read() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
/// Acquires a write guard on the gateway URL override.
///
/// A poisoned lock is not treated as fatal: the guard carried by the poison
/// error is returned instead, so the stored value can still be replaced.
fn write_gateway_url_override() -> RwLockWriteGuard<'static, Option<String>> {
    match gateway_url_override().write() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
/// Installs a process-local gateway URL override.
///
/// Surrounding whitespace is stripped from `gateway_url`. If nothing remains,
/// the call is a no-op and any existing override is left untouched.
/// Otherwise the trimmed value takes precedence in `default_gateway_url()`.
pub fn set_gateway_url(gateway_url: impl Into<String>) {
    let raw: String = gateway_url.into();
    let trimmed = raw.trim();
    if !trimmed.is_empty() {
        *write_gateway_url_override() = Some(trimmed.to_string());
    }
}
/// Test-only helper that removes any process-local gateway URL override.
#[cfg(test)]
fn clear_gateway_url_override() {
    let mut slot = write_gateway_url_override();
    *slot = None;
}
/// Resolves the gateway base URL to use when none is configured explicitly.
///
/// Sources are consulted in priority order:
/// 1. the process-local override installed with `set_gateway_url`,
/// 2. the `XYBRID_GATEWAY_URL` environment variable, used as-is after trimming,
/// 3. the `XYBRID_PLATFORM_URL` environment variable, with trailing slashes
///    removed and `/v1` appended,
/// 4. the built-in default `https://api.xybrid.dev/v1`.
///
/// Environment values are trimmed, and blank (empty or whitespace-only)
/// values are treated as unset. This mirrors `set_gateway_url`, which also
/// ignores blank input, and prevents returning an empty URL or a bare `/v1`.
pub fn default_gateway_url() -> String {
    if let Some(url) = read_gateway_url_override().clone() {
        return url;
    }
    if let Some(url) = non_blank_env_var("XYBRID_GATEWAY_URL") {
        return url;
    }
    if let Some(platform) = non_blank_env_var("XYBRID_PLATFORM_URL") {
        let base = platform.trim_end_matches('/');
        // A value made only of slashes has no usable host; fall through.
        if !base.is_empty() {
            return format!("{}/v1", base);
        }
    }
    "https://api.xybrid.dev/v1".to_string()
}

/// Reads environment variable `name`, returning its trimmed value, or `None`
/// when it is unset, not valid Unicode, or blank after trimming.
fn non_blank_env_var(name: &str) -> Option<String> {
    std::env::var(name)
        .ok()
        .map(|value| value.trim().to_owned())
        .filter(|value| !value.is_empty())
}
pub use xybrid_core::ir::MessageRole;
/// A single chat message: a role paired with its text content.
///
/// Use `ChatMessage::system`, `ChatMessage::user` or `ChatMessage::assistant`
/// to build one with the matching role.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
/// Who the message is attributed to (`MessageRole` is re-exported from
/// `xybrid_core::ir`).
pub role: MessageRole,
/// The message text.
pub content: String,
}
impl ChatMessage {
    /// Builds a message carrying `content` with the `System` role.
    pub fn system(content: String) -> Self {
        let role = MessageRole::System;
        Self { role, content }
    }

    /// Builds a message carrying `content` with the `User` role.
    pub fn user(content: String) -> Self {
        let role = MessageRole::User;
        Self { role, content }
    }

    /// Builds a message carrying `content` with the `Assistant` role.
    pub fn assistant(content: String) -> Self {
        let role = MessageRole::Assistant;
        Self { role, content }
    }
}
/// Parameters for a completion request.
///
/// Every field is optional and the derived `Default` leaves them all `None`.
/// `CompletionRequest::new` fills only `prompt`; `CompletionRequest::chat`
/// fills only `messages`; the `with_*` builders set the remaining fields.
///
/// NOTE(review): how unset fields and a simultaneously set `prompt` and
/// `messages` are interpreted is decided by the consumer, not visible here.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CompletionRequest {
/// Model name to use for this request, if any.
pub model: Option<String>,
/// Plain prompt text (set by `CompletionRequest::new`).
pub prompt: Option<String>,
/// Chat-style message list (set by `CompletionRequest::chat`).
pub messages: Option<Vec<ChatMessage>>,
/// System text (set by `with_system`). Presumably a system prompt — TODO confirm.
pub system: Option<String>,
/// Token limit (set by `with_max_tokens`). Presumably caps generated tokens — TODO confirm.
pub max_tokens: Option<u32>,
/// Sampling temperature (set by `with_temperature`); valid range not enforced here.
pub temperature: Option<f32>,
/// `top_p` sampling value (set by `with_top_p`); valid range not enforced here.
pub top_p: Option<f32>,
/// Stop strings (set by `with_stop`).
pub stop: Option<Vec<String>>,
}
impl CompletionRequest {
    /// Creates a request holding only `prompt`; all other fields stay `None`.
    pub fn new(prompt: String) -> Self {
        Self {
            prompt: Some(prompt),
            ..Self::default()
        }
    }

    /// Creates a request holding only `messages`; all other fields stay `None`.
    pub fn chat(messages: Vec<ChatMessage>) -> Self {
        Self {
            messages: Some(messages),
            ..Self::default()
        }
    }

    /// Returns the request with `model` set, replacing any previous value.
    pub fn with_model(self, model: String) -> Self {
        Self {
            model: Some(model),
            ..self
        }
    }

    /// Returns the request with `system` set, replacing any previous value.
    pub fn with_system(self, system: String) -> Self {
        Self {
            system: Some(system),
            ..self
        }
    }

    /// Returns the request with `max_tokens` set, replacing any previous value.
    pub fn with_max_tokens(self, max_tokens: u32) -> Self {
        Self {
            max_tokens: Some(max_tokens),
            ..self
        }
    }

    /// Returns the request with `temperature` set, replacing any previous value.
    pub fn with_temperature(self, temperature: f32) -> Self {
        Self {
            temperature: Some(temperature),
            ..self
        }
    }

    /// Returns the request with `top_p` set, replacing any previous value.
    pub fn with_top_p(self, top_p: f32) -> Self {
        Self {
            top_p: Some(top_p),
            ..self
        }
    }

    /// Returns the request with `stop` set, replacing any previous value.
    pub fn with_stop(self, stop: Vec<String>) -> Self {
        Self {
            stop: Some(stop),
            ..self
        }
    }
}
/// Token counts attached to a completion response.
///
/// The derived `Default` sets every counter to `0`.
/// NOTE(review): nothing in this file enforces
/// `total_tokens == prompt_tokens + completion_tokens`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenUsage {
/// Tokens counted for the prompt.
pub prompt_tokens: u32,
/// Tokens counted for the completion.
pub completion_tokens: u32,
/// Reported total token count.
pub total_tokens: u32,
}
/// Outcome of a completion request, covering both success and failure.
///
/// Build values with `CompletionResponse::success` (sets `success = true`,
/// `error = None`) or `CompletionResponse::error` (sets `success = false`,
/// stores the message, and leaves every other field empty/`None`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionResponse {
/// `true` for responses built by `success`, `false` for those built by `error`.
pub success: bool,
/// Error message; `None` on success.
pub error: Option<String>,
/// Generated text; the empty string on error.
pub text: String,
/// Model name reported for the response; the empty string on error.
pub model: String,
/// Reason generation finished, if reported (the tests use `"stop"` and `"length"`).
pub finish_reason: Option<String>,
/// Token usage, if reported.
pub usage: Option<TokenUsage>,
/// Latency, if measured. Milliseconds, judging by the field name.
pub latency_ms: Option<u32>,
/// Label of the backend that served the request, if known (the tests use `"gateway"`).
pub backend: Option<String>,
}
impl CompletionResponse {
    /// Builds a successful response from the given parts.
    ///
    /// `success` is set to `true` and `error` to `None`; every argument is
    /// stored unchanged in the field of the same name.
    pub fn success(
        text: String,
        model: String,
        finish_reason: Option<String>,
        usage: Option<TokenUsage>,
        latency_ms: Option<u32>,
        backend: Option<String>,
    ) -> Self {
        Self {
            text,
            model,
            finish_reason,
            usage,
            latency_ms,
            backend,
            success: true,
            error: None,
        }
    }

    /// Builds a failed response carrying `message`.
    ///
    /// `success` is `false`, `text` and `model` are empty strings, and every
    /// optional field other than `error` is `None`.
    pub fn error(message: String) -> Self {
        // Start from an all-empty response, then flip it into the error state.
        let mut response = Self::success(String::new(), String::new(), None, None, None, None);
        response.success = false;
        response.error = Some(message);
        response
    }
}
// Unit tests for the configuration, message, request and response types in
// this file: default values, constructor/builder behaviour, and serde JSON
// round-trips.
#[cfg(test)]
mod tests {
use super::*;
// The derived `Default` for `LlmBackend` must pick `Gateway`.
#[test]
fn test_llm_backend_default() {
assert_eq!(LlmBackend::default(), LlmBackend::Gateway);
}
// Pins the default configuration values (the gateway URL is not asserted
// here because it depends on the override / environment).
#[test]
fn test_llm_client_config_default() {
let config = LlmClientConfig::default();
assert_eq!(config.backend, LlmBackend::Gateway);
assert_eq!(config.timeout_ms, 60000);
assert_eq!(config.default_model, Some("gpt-4o-mini".to_string()));
assert!(config.api_key.is_none());
assert!(config.direct_provider.is_none());
assert!(!config.debug);
}
// Only checks the `/v1` suffix of the resolved URL.
// NOTE(review): the result depends on the process-local override and on the
// XYBRID_GATEWAY_URL / XYBRID_PLATFORM_URL environment variables, so this
// can fail if XYBRID_GATEWAY_URL is set to a value not ending in `/v1`.
#[test]
fn test_default_gateway_url_fallback() {
let url = default_gateway_url();
assert!(
url.ends_with("/v1"),
"gateway_url should end with '/v1', got: {}",
url
);
}
// An override installed with `set_gateway_url` wins in `default_gateway_url`.
// The override is process-global state, so it is cleared before and after.
#[test]
fn test_process_local_gateway_url_override() {
clear_gateway_url_override();
set_gateway_url("https://local.gateway.test/v1");
assert_eq!(default_gateway_url(), "https://local.gateway.test/v1");
clear_gateway_url_override();
}
// `LlmBackend` variants serialize to their bare names and parse back.
#[test]
fn test_llm_backend_serialization() {
let gateway = LlmBackend::Gateway;
let direct = LlmBackend::Direct;
let gateway_json = serde_json::to_string(&gateway).unwrap();
let direct_json = serde_json::to_string(&direct).unwrap();
assert_eq!(gateway_json, "\"Gateway\"");
assert_eq!(direct_json, "\"Direct\"");
let gateway_parsed: LlmBackend = serde_json::from_str(&gateway_json).unwrap();
let direct_parsed: LlmBackend = serde_json::from_str(&direct_json).unwrap();
assert_eq!(gateway_parsed, LlmBackend::Gateway);
assert_eq!(direct_parsed, LlmBackend::Direct);
}
// A fully populated config survives a JSON round-trip field by field.
#[test]
fn test_llm_client_config_serialization() {
let config = LlmClientConfig {
backend: LlmBackend::Direct,
gateway_url: "https://test.example.com/v1".to_string(),
api_key: Some("test-key".to_string()),
default_model: Some("gpt-4".to_string()),
timeout_ms: 30000,
direct_provider: Some("openai".to_string()),
debug: true,
};
let json = serde_json::to_string(&config).unwrap();
let parsed: LlmClientConfig = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.backend, LlmBackend::Direct);
assert_eq!(parsed.gateway_url, "https://test.example.com/v1");
assert_eq!(parsed.api_key, Some("test-key".to_string()));
assert_eq!(parsed.default_model, Some("gpt-4".to_string()));
assert_eq!(parsed.timeout_ms, 30000);
assert_eq!(parsed.direct_provider, Some("openai".to_string()));
assert!(parsed.debug);
}
// `ChatMessage::system` sets the System role and keeps the content.
#[test]
fn test_chat_message_system() {
let msg = ChatMessage::system("You are helpful".to_string());
assert_eq!(msg.role, MessageRole::System);
assert_eq!(msg.content, "You are helpful");
}
// `ChatMessage::user` sets the User role and keeps the content.
#[test]
fn test_chat_message_user() {
let msg = ChatMessage::user("Hello!".to_string());
assert_eq!(msg.role, MessageRole::User);
assert_eq!(msg.content, "Hello!");
}
// `ChatMessage::assistant` sets the Assistant role and keeps the content.
#[test]
fn test_chat_message_assistant() {
let msg = ChatMessage::assistant("Hi there!".to_string());
assert_eq!(msg.role, MessageRole::Assistant);
assert_eq!(msg.content, "Hi there!");
}
// `MessageRole` serializes to lowercase names and parses back.
#[test]
fn test_message_role_serialization() {
let system = MessageRole::System;
let user = MessageRole::User;
let assistant = MessageRole::Assistant;
let system_json = serde_json::to_string(&system).unwrap();
let user_json = serde_json::to_string(&user).unwrap();
let assistant_json = serde_json::to_string(&assistant).unwrap();
assert_eq!(system_json, "\"system\"");
assert_eq!(user_json, "\"user\"");
assert_eq!(assistant_json, "\"assistant\"");
let system_parsed: MessageRole = serde_json::from_str(&system_json).unwrap();
let user_parsed: MessageRole = serde_json::from_str(&user_json).unwrap();
let assistant_parsed: MessageRole = serde_json::from_str(&assistant_json).unwrap();
assert_eq!(system_parsed, MessageRole::System);
assert_eq!(user_parsed, MessageRole::User);
assert_eq!(assistant_parsed, MessageRole::Assistant);
}
// A chat message survives a JSON round-trip.
#[test]
fn test_chat_message_serialization() {
let msg = ChatMessage::user("Hello!".to_string());
let json = serde_json::to_string(&msg).unwrap();
let parsed: ChatMessage = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.role, MessageRole::User);
assert_eq!(parsed.content, "Hello!");
}
// `CompletionRequest::new` fills only the prompt.
#[test]
fn test_completion_request_new() {
let request = CompletionRequest::new("Hello".to_string());
assert_eq!(request.prompt, Some("Hello".to_string()));
assert!(request.messages.is_none());
assert!(request.model.is_none());
}
// `CompletionRequest::chat` fills only the messages, preserving their order.
#[test]
fn test_completion_request_chat() {
let messages = vec![
ChatMessage::system("Be helpful".to_string()),
ChatMessage::user("Hello".to_string()),
];
let request = CompletionRequest::chat(messages);
assert!(request.prompt.is_none());
assert!(request.messages.is_some());
let msgs = request.messages.unwrap();
assert_eq!(msgs.len(), 2);
assert_eq!(msgs[0].role, MessageRole::System);
assert_eq!(msgs[1].role, MessageRole::User);
}
// Chaining every `with_*` builder sets each corresponding field.
#[test]
fn test_completion_request_builder() {
let request = CompletionRequest::new("Hello".to_string())
.with_model("gpt-4".to_string())
.with_system("Be helpful".to_string())
.with_max_tokens(100)
.with_temperature(0.7)
.with_top_p(0.9)
.with_stop(vec!["END".to_string()]);
assert_eq!(request.prompt, Some("Hello".to_string()));
assert_eq!(request.model, Some("gpt-4".to_string()));
assert_eq!(request.system, Some("Be helpful".to_string()));
assert_eq!(request.max_tokens, Some(100));
assert_eq!(request.temperature, Some(0.7));
assert_eq!(request.top_p, Some(0.9));
assert_eq!(request.stop, Some(vec!["END".to_string()]));
}
// A prompt-style request survives a JSON round-trip.
#[test]
fn test_completion_request_serialization() {
let request = CompletionRequest::new("Hello".to_string())
.with_model("gpt-4".to_string())
.with_max_tokens(100);
let json = serde_json::to_string(&request).unwrap();
let parsed: CompletionRequest = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.prompt, Some("Hello".to_string()));
assert_eq!(parsed.model, Some("gpt-4".to_string()));
assert_eq!(parsed.max_tokens, Some(100));
}
// A chat-style request keeps message order, roles and content through JSON.
#[test]
fn test_completion_request_with_messages_serialization() {
let messages = vec![
ChatMessage::system("System message".to_string()),
ChatMessage::user("User message".to_string()),
ChatMessage::assistant("Assistant message".to_string()),
];
let request = CompletionRequest::chat(messages);
let json = serde_json::to_string(&request).unwrap();
let parsed: CompletionRequest = serde_json::from_str(&json).unwrap();
let msgs = parsed.messages.unwrap();
assert_eq!(msgs.len(), 3);
assert_eq!(msgs[0].role, MessageRole::System);
assert_eq!(msgs[0].content, "System message");
assert_eq!(msgs[1].role, MessageRole::User);
assert_eq!(msgs[1].content, "User message");
assert_eq!(msgs[2].role, MessageRole::Assistant);
assert_eq!(msgs[2].content, "Assistant message");
}
// The derived `Default` for `TokenUsage` zeroes every counter.
#[test]
fn test_token_usage_default() {
let usage = TokenUsage::default();
assert_eq!(usage.prompt_tokens, 0);
assert_eq!(usage.completion_tokens, 0);
assert_eq!(usage.total_tokens, 0);
}
// Token counts survive a JSON round-trip.
#[test]
fn test_token_usage_serialization() {
let usage = TokenUsage {
prompt_tokens: 10,
completion_tokens: 20,
total_tokens: 30,
};
let json = serde_json::to_string(&usage).unwrap();
let parsed: TokenUsage = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.prompt_tokens, 10);
assert_eq!(parsed.completion_tokens, 20);
assert_eq!(parsed.total_tokens, 30);
}
// `CompletionResponse::success` stores every argument and marks success.
#[test]
fn test_completion_response_success() {
let usage = TokenUsage {
prompt_tokens: 5,
completion_tokens: 15,
total_tokens: 20,
};
let response = CompletionResponse::success(
"Hello, world!".to_string(),
"gpt-4".to_string(),
Some("stop".to_string()),
Some(usage),
Some(150),
Some("gateway".to_string()),
);
assert!(response.success);
assert!(response.error.is_none());
assert_eq!(response.text, "Hello, world!");
assert_eq!(response.model, "gpt-4");
assert_eq!(response.finish_reason, Some("stop".to_string()));
assert_eq!(response.latency_ms, Some(150));
assert_eq!(response.backend, Some("gateway".to_string()));
assert!(response.usage.is_some());
let u = response.usage.unwrap();
assert_eq!(u.prompt_tokens, 5);
assert_eq!(u.completion_tokens, 15);
assert_eq!(u.total_tokens, 20);
}
// `CompletionResponse::error` stores the message and leaves the rest empty.
#[test]
fn test_completion_response_error() {
let response = CompletionResponse::error("Connection timeout".to_string());
assert!(!response.success);
assert_eq!(response.error, Some("Connection timeout".to_string()));
assert_eq!(response.text, "");
assert_eq!(response.model, "");
assert!(response.finish_reason.is_none());
assert!(response.usage.is_none());
assert!(response.latency_ms.is_none());
assert!(response.backend.is_none());
}
// A successful response (with some `None` fields) survives a JSON round-trip.
#[test]
fn test_completion_response_serialization() {
let response = CompletionResponse::success(
"Test response".to_string(),
"claude-3".to_string(),
Some("length".to_string()),
None,
Some(200),
None,
);
let json = serde_json::to_string(&response).unwrap();
let parsed: CompletionResponse = serde_json::from_str(&json).unwrap();
assert!(parsed.success);
assert_eq!(parsed.text, "Test response");
assert_eq!(parsed.model, "claude-3");
assert_eq!(parsed.finish_reason, Some("length".to_string()));
assert_eq!(parsed.latency_ms, Some(200));
}
// An error response survives a JSON round-trip.
#[test]
fn test_completion_response_error_serialization() {
let response = CompletionResponse::error("API rate limit exceeded".to_string());
let json = serde_json::to_string(&response).unwrap();
let parsed: CompletionResponse = serde_json::from_str(&json).unwrap();
assert!(!parsed.success);
assert_eq!(parsed.error, Some("API rate limit exceeded".to_string()));
assert_eq!(parsed.text, "");
}
}