use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Instant;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use swarm_engine_core::learn::lora::EndpointResolver;
use swarm_engine_core::types::LoraConfig;
use crate::debug_channel::{LlmDebugChannel, LlmDebugEvent};
use crate::decider::{DecisionResponse, LlmDecider, LlmError, WorkerDecisionRequest};
use crate::prompt_builder::PromptBuilder;
use crate::response_parser;
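/// Configuration for a llama.cpp `llama-server` instance reached over HTTP.
///
/// Defaults point at a local server on port 8080 and wrap prompts with the
/// LFM2 chat template; individual fields can be overridden with the builder
/// methods. Illustrative only (the endpoint and model name below are
/// placeholders, not values this crate ships):
///
/// ```ignore
/// let config = LlamaCppServerConfig::new("http://localhost:8080")
///     .with_model_name("my-model")
///     .with_chat_template(ChatTemplate::Qwen)
///     .with_max_tokens(128);
/// ```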
#[derive(Debug, Clone)]
pub struct LlamaCppServerConfig {
pub endpoint: String,
pub model_name: String,
pub max_tokens: usize,
pub temperature: f32,
pub top_p: f32,
pub timeout_secs: u64,
pub chat_template: Option<ChatTemplate>,
}
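/// Chat template used to wrap the raw prompt in model-specific markers before
/// it is sent to the server. For example, [`ChatTemplate::Lfm2`] formats
/// `"Hello"` as `"<|user|>\nHello\n<|assistant|>\n"`.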
#[derive(Debug, Clone)]
pub enum ChatTemplate {
Lfm2,
Qwen,
Llama3,
Custom {
user_prefix: String,
user_suffix: String,
assistant_prefix: String,
},
}
impl ChatTemplate {
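    /// Wraps `prompt` in this template's user/assistant markers, leaving the
    /// result ready for the assistant turn to be generated.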
pub fn format(&self, prompt: &str) -> String {
match self {
ChatTemplate::Lfm2 => {
format!("<|user|>\n{}\n<|assistant|>\n", prompt)
}
ChatTemplate::Qwen => {
format!(
"<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
prompt
)
}
ChatTemplate::Llama3 => {
format!("<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", prompt)
}
ChatTemplate::Custom {
user_prefix,
user_suffix,
assistant_prefix,
} => {
format!(
"{}{}{}{}",
user_prefix, prompt, user_suffix, assistant_prefix
)
}
}
}
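    /// Stop sequences the server should treat as end-of-turn for this
    /// template. Custom templates supply no stop tokens.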
pub fn stop_tokens(&self) -> &'static [&'static str] {
match self {
ChatTemplate::Lfm2 => &["<|user|>", "<|endoftext|>"],
ChatTemplate::Qwen => &["<|im_end|>", "<|im_start|>", "<|endoftext|>"],
ChatTemplate::Llama3 => &["<|eot_id|>", "<|start_header_id|>"],
            ChatTemplate::Custom { .. } => &[],
        }
}
}
impl Default for LlamaCppServerConfig {
fn default() -> Self {
Self {
endpoint: "http://localhost:8080".to_string(),
model_name: "llama-server".to_string(),
max_tokens: 256,
temperature: 0.7,
top_p: 0.9,
timeout_secs: 30,
            chat_template: Some(ChatTemplate::Lfm2),
        }
}
}
impl LlamaCppServerConfig {
pub fn new(endpoint: impl Into<String>) -> Self {
Self {
endpoint: endpoint.into(),
..Default::default()
}
}
pub fn with_model_name(mut self, name: impl Into<String>) -> Self {
self.model_name = name.into();
self
}
pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
self.max_tokens = max_tokens;
self
}
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = temperature;
self
}
pub fn with_top_p(mut self, top_p: f32) -> Self {
self.top_p = top_p;
self
}
pub fn with_timeout(mut self, secs: u64) -> Self {
self.timeout_secs = secs;
self
}
pub fn with_chat_template(mut self, template: ChatTemplate) -> Self {
self.chat_template = Some(template);
self
}
pub fn without_chat_template(mut self) -> Self {
self.chat_template = None;
self
}
}
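/// Per-request LoRA adapter selection in the shape llama-server expects: an
/// adapter `id` (its index on the server side) and a blending `scale`.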
#[derive(Debug, Serialize)]
struct LoraAdapterRequest {
id: u32,
scale: f32,
}
impl From<&LoraConfig> for LoraAdapterRequest {
fn from(config: &LoraConfig) -> Self {
Self {
id: config.id,
scale: config.scale,
}
}
}
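/// Request body posted to the server's `/completion` endpoint.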
#[derive(Debug, Serialize)]
struct CompletionRequest {
prompt: String,
n_predict: usize,
temperature: f32,
top_p: f32,
stream: bool,
#[serde(skip_serializing_if = "Vec::is_empty")]
stop: Vec<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
lora: Vec<LoraAdapterRequest>,
}
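/// Fields of the `/completion` response that this decider consumes.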
#[derive(Debug, Deserialize)]
struct CompletionResponse {
content: String,
    /// Whether generation stopped on an end-of-sequence token. Deserialized
    /// from the server's `stopped_eos` field but currently unused; defaults to
    /// `false` when the field is absent.
    #[serde(default, rename = "stopped_eos")]
    _stopped_eos: bool,
}
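/// Minimal view of the server's `/health` response.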
#[derive(Debug, Deserialize)]
struct HealthResponse {
status: String,
}
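/// [`LlmDecider`] backed by a llama.cpp `llama-server` HTTP API.
///
/// Completions go to `{endpoint}/completion`, health checks to
/// `{endpoint}/health`, and `{endpoint}/slots` is probed to estimate the
/// server's maximum concurrency. Cloning is cheap: the HTTP client is shared
/// behind an [`Arc`]. An optional [`EndpointResolver`] can redirect requests
/// to a dynamically chosen endpoint at call time.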
pub struct LlamaCppServerDecider {
config: LlamaCppServerConfig,
client: Arc<Client>,
prompt_builder: PromptBuilder,
endpoint_resolver: Option<Arc<dyn EndpointResolver>>,
}
impl Clone for LlamaCppServerDecider {
fn clone(&self) -> Self {
Self {
config: self.config.clone(),
client: Arc::clone(&self.client),
prompt_builder: self.prompt_builder.clone(),
endpoint_resolver: self.endpoint_resolver.clone(),
}
}
}
impl LlamaCppServerDecider {
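    /// Builds the decider and its HTTP client. This does not contact the
    /// server; it fails only if the client itself cannot be constructed.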
pub fn new(config: LlamaCppServerConfig) -> Result<Self, LlmError> {
let client = Client::builder()
.timeout(std::time::Duration::from_secs(config.timeout_secs))
.build()
.map_err(|e| LlmError::permanent(format!("Failed to create HTTP client: {}", e)))?;
Ok(Self {
config,
client: Arc::new(client),
prompt_builder: PromptBuilder::new(),
endpoint_resolver: None,
})
}
pub fn with_endpoint_resolver(mut self, resolver: Arc<dyn EndpointResolver>) -> Self {
self.endpoint_resolver = Some(resolver);
self
}
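    /// Endpoint to use for the next request: the resolver's choice when one is
    /// configured, otherwise the static endpoint from the config.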
fn current_endpoint(&self) -> String {
if let Some(ref resolver) = self.endpoint_resolver {
resolver.current_endpoint()
} else {
self.config.endpoint.clone()
}
}
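    /// Formats the prompt with the configured chat template (if any), posts it
    /// to `{endpoint}/completion`, and returns the completion text, the
    /// formatted prompt, and the request latency in milliseconds. Timeouts and
    /// connection failures are reported as transient errors; everything else
    /// is permanent.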
async fn call_server(
&self,
prompt: &str,
lora: Option<&LoraConfig>,
) -> Result<(String, String, u64), LlmError> {
let start = Instant::now();
let (formatted_prompt, stop_tokens) = if let Some(ref template) = self.config.chat_template
{
let stop = template
.stop_tokens()
.iter()
.map(|s| s.to_string())
.collect();
(template.format(prompt), stop)
} else {
(prompt.to_string(), vec![])
};
let lora_adapters: Vec<LoraAdapterRequest> = lora
.map(|l| vec![LoraAdapterRequest::from(l)])
.unwrap_or_default();
let request = CompletionRequest {
prompt: formatted_prompt.clone(),
n_predict: self.config.max_tokens,
temperature: self.config.temperature,
top_p: self.config.top_p,
stream: false,
stop: stop_tokens,
lora: lora_adapters,
};
let endpoint = self.current_endpoint();
let url = format!("{}/completion", endpoint);
let response = self
.client
.post(&url)
.json(&request)
.send()
.await
.map_err(|e| {
if e.is_timeout() {
LlmError::transient(format!("Request timeout: {}", e))
} else if e.is_connect() {
LlmError::transient(format!("Connection error: {}", e))
} else {
LlmError::permanent(format!("HTTP error: {}", e))
}
})?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(LlmError::permanent(format!(
"Server error {}: {}",
status, body
)));
}
let completion: CompletionResponse = response
.json()
.await
.map_err(|e| LlmError::permanent(format!("Failed to parse response: {}", e)))?;
let latency_ms = start.elapsed().as_millis() as u64;
Ok((completion.content, formatted_prompt, latency_ms))
}
fn emit_debug_event(&self, event: LlmDebugEvent) {
LlmDebugChannel::global().emit(event);
}
}
impl LlmDecider for LlamaCppServerDecider {
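    /// Builds a prompt from the worker context, requests a completion, parses
    /// it into a [`DecisionResponse`], and emits a debug event on both the
    /// success and failure paths.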
fn decide(
&self,
request: WorkerDecisionRequest,
) -> Pin<Box<dyn Future<Output = Result<DecisionResponse, LlmError>> + Send + '_>> {
let current_endpoint = self.current_endpoint();
Box::pin(async move {
let prompt = self.prompt_builder.build(&request.context);
let worker_id = request.worker_id.0;
let lora = request.lora.as_ref();
let (raw_response, _formatted_prompt, latency_ms) =
match self.call_server(&prompt, lora).await {
Ok(result) => result,
Err(e) => {
self.emit_debug_event(
LlmDebugEvent::new("decide", &self.config.model_name)
.worker_id(worker_id)
                                .endpoint(&current_endpoint)
.prompt(&prompt)
.lora_opt(request.lora.clone())
.error(e.message()),
);
return Err(e);
}
};
let candidate_names = response_parser::candidate_names(&request.context.candidates);
match response_parser::parse_response(&raw_response, &candidate_names) {
Ok(mut d) => {
self.emit_debug_event(
LlmDebugEvent::new("decide", &self.config.model_name)
.worker_id(worker_id)
                            .endpoint(&current_endpoint)
.prompt(&prompt)
.response(&raw_response)
.lora_opt(request.lora.clone())
.latency_ms(latency_ms),
);
d.prompt = Some(prompt);
d.raw_response = Some(raw_response);
Ok(d)
}
Err(e) => {
self.emit_debug_event(
LlmDebugEvent::new("decide", &self.config.model_name)
.worker_id(worker_id)
                            .endpoint(&current_endpoint)
.prompt(&prompt)
.response(&raw_response)
.lora_opt(request.lora.clone())
.error(e.message())
.latency_ms(latency_ms),
);
tracing::warn!(error = %e, "Parse error");
tracing::debug!(raw = %raw_response, "Raw response");
Err(e)
}
}
})
}
fn call_raw(
&self,
prompt: &str,
lora: Option<&LoraConfig>,
) -> Pin<Box<dyn Future<Output = Result<String, LlmError>> + Send + '_>> {
let prompt = prompt.to_string();
let lora_owned = lora.cloned();
let current_endpoint = self.current_endpoint();
Box::pin(async move {
match self.call_server(&prompt, lora_owned.as_ref()).await {
Ok((response, _formatted_prompt, latency_ms)) => {
self.emit_debug_event(
LlmDebugEvent::new("call_raw", &self.config.model_name)
                            .endpoint(&current_endpoint)
.prompt(&prompt)
.response(&response)
.lora_opt(lora_owned.clone())
.latency_ms(latency_ms),
);
Ok(response)
}
Err(e) => {
self.emit_debug_event(
LlmDebugEvent::new("call_raw", &self.config.model_name)
                            .endpoint(&current_endpoint)
.prompt(&prompt)
.lora_opt(lora_owned)
.error(e.message()),
);
Err(e)
}
}
})
}
fn model_name(&self) -> &str {
&self.config.model_name
}
fn endpoint(&self) -> &str {
&self.config.endpoint
}
fn is_healthy(&self) -> Pin<Box<dyn Future<Output = bool> + Send + '_>> {
let client = Arc::clone(&self.client);
let endpoint = self.current_endpoint();
Box::pin(async move {
let url = format!("{}/health", endpoint);
match client.get(&url).send().await {
Ok(response) => {
if let Ok(health) = response.json::<HealthResponse>().await {
health.status == "ok"
} else {
false
}
}
Err(_) => false,
}
})
}
fn max_concurrency(&self) -> Pin<Box<dyn Future<Output = Option<usize>> + Send + '_>> {
let client = Arc::clone(&self.client);
let endpoint = self.current_endpoint();
Box::pin(async move {
let url = format!("{}/slots", endpoint);
match client.get(&url).send().await {
Ok(response) => {
if let Ok(slots) = response.json::<Vec<serde_json::Value>>().await {
Some(slots.len())
} else {
None
}
}
Err(_) => None,
}
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_config_default() {
let config = LlamaCppServerConfig::default();
assert_eq!(config.endpoint, "http://localhost:8080");
assert_eq!(config.max_tokens, 256);
assert!(matches!(config.chat_template, Some(ChatTemplate::Lfm2)));
}
#[test]
fn test_config_builder() {
let config = LlamaCppServerConfig::new("http://192.168.1.100:9000")
.with_model_name("my-model")
.with_max_tokens(512)
.with_temperature(0.5)
.with_top_p(0.95)
.with_timeout(60);
assert_eq!(config.endpoint, "http://192.168.1.100:9000");
assert_eq!(config.model_name, "my-model");
assert_eq!(config.max_tokens, 512);
assert!((config.temperature - 0.5).abs() < f32::EPSILON);
assert!((config.top_p - 0.95).abs() < f32::EPSILON);
assert_eq!(config.timeout_secs, 60);
}
#[test]
fn test_config_chat_template() {
let config = LlamaCppServerConfig::default().with_chat_template(ChatTemplate::Qwen);
assert!(matches!(config.chat_template, Some(ChatTemplate::Qwen)));
let config = LlamaCppServerConfig::default().without_chat_template();
assert!(config.chat_template.is_none());
}
#[test]
fn test_chat_template_lfm2() {
let template = ChatTemplate::Lfm2;
let formatted = template.format("Hello");
assert_eq!(formatted, "<|user|>\nHello\n<|assistant|>\n");
}
#[test]
fn test_chat_template_qwen() {
let template = ChatTemplate::Qwen;
let formatted = template.format("Hello");
assert!(formatted.contains("<|im_start|>user"));
assert!(formatted.contains("<|im_end|>"));
assert!(formatted.contains("<|im_start|>assistant"));
}
#[test]
fn test_chat_template_llama3() {
let template = ChatTemplate::Llama3;
let formatted = template.format("Hello");
assert!(formatted.contains("<|start_header_id|>user"));
assert!(formatted.contains("<|eot_id|>"));
}
#[test]
fn test_chat_template_custom() {
let template = ChatTemplate::Custom {
user_prefix: "[USER]".to_string(),
user_suffix: "[/USER]".to_string(),
assistant_prefix: "[ASSISTANT]".to_string(),
};
let formatted = template.format("Hello");
assert_eq!(formatted, "[USER]Hello[/USER][ASSISTANT]");
}
#[test]
fn test_chat_template_stop_tokens() {
let lfm2 = ChatTemplate::Lfm2;
let stop = lfm2.stop_tokens();
assert!(stop.contains(&"<|user|>"));
assert!(stop.contains(&"<|endoftext|>"));
let qwen = ChatTemplate::Qwen;
let stop = qwen.stop_tokens();
assert!(stop.contains(&"<|im_end|>"));
let custom = ChatTemplate::Custom {
user_prefix: "[U]".to_string(),
user_suffix: "[/U]".to_string(),
assistant_prefix: "[A]".to_string(),
};
assert!(custom.stop_tokens().is_empty());
}
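    // The tests below exercise construction paths only: they build the reqwest
    // client but never contact a server, so they run offline. Endpoints and
    // model names are illustrative placeholders.
    #[test]
    fn test_decider_construction() {
        let config = LlamaCppServerConfig::new("http://localhost:8080")
            .with_model_name("test-model");
        let decider = LlamaCppServerDecider::new(config).expect("HTTP client should build");
        assert_eq!(decider.model_name(), "test-model");
        assert_eq!(decider.endpoint(), "http://localhost:8080");
    }
    #[test]
    fn test_current_endpoint_without_resolver() {
        // Without an EndpointResolver, the decider should fall back to the
        // statically configured endpoint.
        let decider = LlamaCppServerDecider::new(LlamaCppServerConfig::new("http://10.0.0.5:9000"))
            .expect("HTTP client should build");
        assert_eq!(decider.current_endpoint(), "http://10.0.0.5:9000");
    }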
}