use crate::ir::{Envelope, EnvelopeKind};
use crate::runtime_adapter::{
AdapterError, AdapterResult, ModelMetadata, RuntimeAdapter, RuntimeAdapterExt,
};
use std::collections::HashMap;
use std::path::Path;
pub use super::types::{
ChatMessage, GenerationConfig, LlmConfig, PartialToken, StreamingCallback, StreamingError,
};
/// Result alias used by LLM backends: `Ok(T)` on success, [`AdapterError`] on failure.
pub type LlmResult<T> = Result<T, AdapterError>;
/// Maps a backend name to the execution-provider label reported in traces.
///
/// `"llama-cpp"` and `"mistral"` are resolved by their backend-specific
/// helpers; every other name (including the empty string) yields `"cpu"`.
pub(crate) fn local_execution_provider(backend_name: &str) -> &'static str {
    if backend_name == "llama-cpp" {
        return llamacpp_execution_provider();
    }
    if backend_name == "mistral" {
        return mistral_execution_provider();
    }
    // Unknown backends are reported as plain CPU execution.
    "cpu"
}
/// Execution-provider label for the llama.cpp backend.
///
/// Returns `"metal"` only when the `llm-llamacpp` feature is enabled and the
/// target OS is macOS or iOS; every other build configuration returns `"cpu"`.
fn llamacpp_execution_provider() -> &'static str {
    // `cfg!` is evaluated at compile time, so this is a constant per build.
    let metal_build = cfg!(all(
        feature = "llm-llamacpp",
        any(target_os = "macos", target_os = "ios")
    ));
    if metal_build {
        "metal"
    } else {
        "cpu"
    }
}
/// Execution-provider label for the Mistral backend.
///
/// `llm-mistral-metal` takes precedence and yields `"metal"`; otherwise
/// `llm-mistral-cuda` yields `"cuda"`; with neither feature the label is `"cpu"`.
fn mistral_execution_provider() -> &'static str {
    // `cfg!` is evaluated at compile time, so this is a constant per build.
    if cfg!(feature = "llm-mistral-metal") {
        "metal"
    } else if cfg!(feature = "llm-mistral-cuda") {
        "cuda"
    } else {
        "cpu"
    }
}
/// Result of one generation call: the produced text plus the metrics a
/// backend reports alongside it.
///
/// `LlmRuntimeAdapter::execute` copies most of these fields into the response
/// envelope metadata and into tracing metadata; the `Option` metrics are only
/// exported when they are `Some`.
#[derive(Debug, Clone)]
pub struct GenerationOutput {
/// Generated text; returned as the `Text` payload of the response envelope.
pub text: String,
/// Token count reported by the backend (exported as `tokens_generated` and `tokens_out`).
pub tokens_generated: usize,
/// Generation time reported by the backend (presumably wall-clock milliseconds — confirm with backends).
pub generation_time_ms: u64,
/// Throughput reported by the backend; exported with two decimal places.
pub tokens_per_second: f32,
/// Reason generation ended, as a backend-supplied string (the tests in this file use `"stop"`).
pub finish_reason: String,
/// Presumably time-to-first-token in milliseconds — TODO confirm with backend implementations.
pub ttft_ms: Option<u64>,
/// Presumably mean inter-token latency in milliseconds — TODO confirm.
pub mean_itl_ms: Option<f32>,
/// Presumably 95th-percentile inter-token latency in milliseconds — TODO confirm.
pub p95_itl_ms: Option<u32>,
/// Presumably the number of chunks emitted while streaming — TODO confirm.
pub emitted_chunks: Option<u32>,
/// Presumably the gaps between emitted chunks in milliseconds — TODO confirm.
/// Not exported by `LlmRuntimeAdapter::execute` in this file.
pub inter_chunk_ms: Vec<u32>,
/// Presumably decode-phase tokens per second — TODO confirm.
pub decode_tps: Option<f32>,
/// Presumably prefill-phase tokens per second — TODO confirm.
pub prefill_tps: Option<f32>,
}
/// Interface every LLM inference backend implements so that
/// `LlmRuntimeAdapter` can drive it through a `Box<dyn LlmBackend>`.
///
/// Implementors must be `Send + Sync`.
pub trait LlmBackend: Send + Sync {
    /// Backend identifier. The adapter stores it as the model's
    /// `runtime_type` and uses it to pick the execution-provider label.
    fn name(&self) -> &str;

    /// Model formats this backend accepts; the adapter forwards this list
    /// unchanged from its own `supported_formats`.
    fn supported_formats(&self) -> Vec<&'static str>;

    /// Loads the model described by `config`.
    fn load(&mut self, config: &LlmConfig) -> LlmResult<()>;

    /// Whether a model is currently loaded.
    fn is_loaded(&self) -> bool;

    /// Unloads the current model.
    fn unload(&mut self) -> LlmResult<()>;

    /// Generates a completion for a chat-style message list.
    fn generate(
        &self,
        messages: &[ChatMessage],
        config: &GenerationConfig,
    ) -> LlmResult<GenerationOutput>;

    /// Generates a completion from a single prompt string
    /// (presumably without chat formatting — confirm with implementors).
    fn generate_raw(&self, prompt: &str, config: &GenerationConfig) -> LlmResult<GenerationOutput>;

    /// Streaming generation.
    ///
    /// The default implementation does not stream incrementally: it runs
    /// [`LlmBackend::generate`] to completion and then invokes `on_token`
    /// exactly once with the whole text as a single token (index `0`, no
    /// token id, finish reason set).
    ///
    /// # Errors
    /// Returns the error from `generate`, or the callback's error converted
    /// via `AdapterError::from_streaming_callback_error`.
    fn generate_streaming(
        &self,
        messages: &[ChatMessage],
        config: &GenerationConfig,
        mut on_token: StreamingCallback<'_>,
    ) -> LlmResult<GenerationOutput> {
        let output = self.generate(messages, config)?;
        // Deliver the complete response as one final "token".
        on_token(PartialToken {
            token: output.text.clone(),
            token_id: None,
            index: 0,
            cumulative_text: output.text.clone(),
            finish_reason: Some(output.finish_reason.clone()),
        })
        .map_err(AdapterError::from_streaming_callback_error)?;
        Ok(output)
    }

    /// Whether the backend overrides streaming with a real implementation.
    /// Defaults to `false`.
    fn supports_streaming(&self) -> bool {
        false
    }

    /// Memory used by the backend, if it can report it. Defaults to `None`.
    fn memory_usage(&self) -> Option<u64> {
        None
    }

    /// Context length of the loaded model, if known. Defaults to `None`.
    fn context_length(&self) -> Option<usize> {
        None
    }

    /// Length of the cached prompt prefix from the last call, if the backend
    /// tracks one. The adapter reports positive values as
    /// `prompt_cached_tokens`. Defaults to `None`.
    fn last_cached_prefix_len(&self) -> Option<usize> {
        None
    }

    /// Optional static label for the backend. Defaults to `None`.
    fn wire_label(&self) -> Option<&'static str> {
        None
    }
}
/// Function pointer that constructs a boxed [`LlmBackend`], or fails with an [`AdapterError`].
pub type BackendFactory = fn() -> LlmResult<Box<dyn LlmBackend>>;
/// Runtime adapter that wraps a boxed [`LlmBackend`] and tracks which model
/// is loaded and the default generation settings.
pub struct LlmRuntimeAdapter {
// Backend that performs loading and generation; all inference is delegated to it.
backend: Box<dyn LlmBackend>,
// Metadata of the loaded model; set by `load_model_with_config`, cleared by `unload_model`.
metadata: Option<ModelMetadata>,
// Path of the loaded model; set and cleared together with `metadata`.
current_model_path: Option<String>,
// Base generation settings; per-request envelope metadata may override individual fields.
default_generation_config: GenerationConfig,
}
impl LlmRuntimeAdapter {
    /// Creates an adapter with the default backend for this build.
    ///
    /// With the `llm-mistral` feature enabled, the default is the Mistral backend.
    ///
    /// # Errors
    /// Propagates any error from constructing the backend.
    #[cfg(feature = "llm-mistral")]
    pub fn new() -> AdapterResult<Self> {
        Self::with_mistral()
    }

    /// Creates an adapter with the default backend for this build.
    ///
    /// With `llm-llamacpp` enabled and `llm-mistral` disabled, the default is
    /// the llama.cpp backend.
    ///
    /// # Errors
    /// Propagates any error from constructing the backend.
    #[cfg(all(feature = "llm-llamacpp", not(feature = "llm-mistral")))]
    pub fn new() -> AdapterResult<Self> {
        Self::with_llamacpp()
    }

    /// Creates an adapter backed by `MistralBackend`.
    ///
    /// # Errors
    /// Propagates any error from `MistralBackend::new`.
    #[cfg(feature = "llm-mistral")]
    pub fn with_mistral() -> AdapterResult<Self> {
        let backend = crate::runtime_adapter::mistral::MistralBackend::new()?;
        Ok(Self::with_backend(Box::new(backend)))
    }

    /// Creates an adapter backed by `LlamaCppBackend`.
    ///
    /// # Errors
    /// Propagates any error from `LlamaCppBackend::new`.
    #[cfg(feature = "llm-llamacpp")]
    pub fn with_llamacpp() -> AdapterResult<Self> {
        let backend = crate::runtime_adapter::llama_cpp::LlamaCppBackend::new()?;
        Ok(Self::with_backend(Box::new(backend)))
    }

    /// Creates an adapter for the backend named by `hint`.
    ///
    /// Recognised hints are `"llamacpp"` and `"mistral"`, each only when the
    /// matching feature is compiled in. `None`, an unknown hint, or a hint for
    /// a backend that is not compiled in all fall back to [`Self::new`].
    #[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
    pub fn with_backend_hint(hint: Option<&str>) -> AdapterResult<Self> {
        match hint {
            #[cfg(feature = "llm-llamacpp")]
            Some("llamacpp") => Self::with_llamacpp(),
            #[cfg(feature = "llm-mistral")]
            Some("mistral") => Self::with_mistral(),
            None | Some(_) => Self::new(),
        }
    }

    /// Wraps an already-constructed backend. No model is considered loaded by
    /// the adapter yet, and the generation defaults are `GenerationConfig::default()`.
    pub fn with_backend(backend: Box<dyn LlmBackend>) -> Self {
        let default_generation_config = GenerationConfig::default();
        Self {
            backend,
            metadata: None,
            current_model_path: None,
            default_generation_config,
        }
    }

    /// Replaces the default generation settings.
    pub fn set_generation_config(&mut self, config: GenerationConfig) {
        self.default_generation_config = config;
    }

    /// Returns the default generation settings.
    pub fn generation_config(&self) -> &GenerationConfig {
        &self.default_generation_config
    }

    /// Memory usage as reported by the backend, if available.
    pub fn memory_usage(&self) -> Option<u64> {
        self.backend.memory_usage()
    }

    /// Context length as reported by the backend, if available.
    pub fn context_length(&self) -> Option<usize> {
        self.backend.context_length()
    }

    /// Wire label as reported by the backend, if any.
    pub fn wire_label(&self) -> Option<&'static str> {
        self.backend.wire_label()
    }

    /// Borrows the underlying backend.
    pub fn backend(&self) -> &dyn LlmBackend {
        &*self.backend
    }

    /// Generates a completion for `prompt` using an explicit `config`.
    ///
    /// When `system` is given it is sent as a system message ahead of the
    /// user prompt. Unlike `execute`, this does not check whether a model is
    /// loaded; that is left to the backend.
    ///
    /// # Errors
    /// Propagates any error from the backend's `generate`.
    pub fn generate_with_config(
        &self,
        prompt: &str,
        system: Option<&str>,
        config: &GenerationConfig,
    ) -> AdapterResult<GenerationOutput> {
        let messages = match system {
            Some(sys) => vec![ChatMessage::system(sys), ChatMessage::user(prompt)],
            None => vec![ChatMessage::user(prompt)],
        };
        self.backend.generate(&messages, config)
    }

    /// Derives a model id from a path: the final component's file stem, or
    /// `"unknown"` when there is no stem or it is not valid UTF-8.
    ///
    /// NOTE(review): `file_stem` strips everything after the last `.`, so a
    /// directory-style path such as `models/Foo2.5-0.5B` would yield
    /// `Foo2.5-0` — confirm whether directory model paths are expected here.
    fn extract_model_id(&self, path: &str) -> String {
        let stem = Path::new(path).file_stem().and_then(|stem| stem.to_str());
        match stem {
            Some(name) => name.to_owned(),
            None => String::from("unknown"),
        }
    }

    /// Builds the generation config for one request.
    ///
    /// Starts from the adapter defaults and overrides `max_tokens`,
    /// `temperature`, `top_p`, `min_p` and `top_k` with values found in
    /// `metadata`. Missing keys and values that fail to parse are ignored, so
    /// the default stays in effect for that field.
    fn parse_generation_config(&self, metadata: &HashMap<String, String>) -> GenerationConfig {
        // Looks up `key` and parses it into whatever type the target field needs.
        fn parsed<T: std::str::FromStr>(
            metadata: &HashMap<String, String>,
            key: &str,
        ) -> Option<T> {
            metadata.get(key)?.parse().ok()
        }

        let mut config = self.default_generation_config.clone();
        if let Some(value) = parsed(metadata, "max_tokens") {
            config.max_tokens = value;
        }
        if let Some(value) = parsed(metadata, "temperature") {
            config.temperature = value;
        }
        if let Some(value) = parsed(metadata, "top_p") {
            config.top_p = value;
        }
        if let Some(value) = parsed(metadata, "min_p") {
            config.min_p = value;
        }
        if let Some(value) = parsed(metadata, "top_k") {
            config.top_k = value;
        }
        config
    }
}
impl LlmRuntimeAdapter {
    /// Loads the model described by `config` into the backend and records its
    /// metadata on the adapter.
    ///
    /// The model id is derived from the path via `extract_model_id`. The
    /// recorded version is the fixed string `"1.0.0"`, and both the input and
    /// output schema contain the single entry `"text" -> [1]`.
    ///
    /// # Errors
    /// - `AdapterError::ModelNotFound` when `config.model_path` does not exist
    ///   on disk; the backend is not touched in that case.
    /// - Any error returned by the backend's `load`; the adapter's recorded
    ///   metadata and model path are left unchanged in that case.
    pub fn load_model_with_config(&mut self, config: &LlmConfig) -> AdapterResult<()> {
        let path = &config.model_path;
        if !Path::new(path).exists() {
            return Err(AdapterError::ModelNotFound(path.to_string()));
        }
        self.backend.load(config)?;
        self.metadata = Some(ModelMetadata {
            // The id is moved straight into the metadata; no extra clone is needed.
            model_id: self.extract_model_id(path),
            version: "1.0.0".to_string(),
            runtime_type: self.backend.name().to_string(),
            model_path: path.to_string(),
            input_schema: {
                let mut schema = HashMap::new();
                schema.insert("text".to_string(), vec![1]);
                schema
            },
            output_schema: {
                let mut schema = HashMap::new();
                schema.insert("text".to_string(), vec![1]);
                schema
            },
        });
        self.current_model_path = Some(path.to_string());
        Ok(())
    }
}
impl RuntimeAdapter for LlmRuntimeAdapter {
    /// Adapter name; always `"llm"`.
    fn name(&self) -> &str {
        "llm"
    }

    /// Formats supported by the wrapped backend.
    fn supported_formats(&self) -> Vec<&'static str> {
        self.backend.supported_formats()
    }

    /// Loads the model at `path` using `LlmConfig::new(path)`.
    fn load_model(&mut self, path: &str) -> AdapterResult<()> {
        let config = LlmConfig::new(path);
        self.load_model_with_config(&config)
    }

    /// Runs one text generation request.
    ///
    /// The envelope must carry `Text`. Its metadata may contain
    /// `system_prompt` (sent as a system message) and the generation
    /// overrides understood by `parse_generation_config`. The response is a
    /// `Text` envelope whose metadata holds the generation metrics; the same
    /// metrics are also attached to the current trace.
    ///
    /// # Errors
    /// - `AdapterError::ModelNotLoaded` when the backend has no model loaded.
    /// - `AdapterError::InvalidInput` for `Audio` or `Embedding` envelopes.
    /// - Any error returned by the backend's `generate`.
    fn execute(&self, input: &Envelope) -> AdapterResult<Envelope> {
        // Records an optional metric in the trace and in the response
        // metadata; does nothing when the metric is absent.
        fn record_optional(
            sink: &mut HashMap<String, String>,
            key: &'static str,
            rendered: Option<String>,
        ) {
            if let Some(text) = rendered {
                crate::tracing::add_metadata(key, &text);
                sink.insert(key.to_string(), text);
            }
        }

        if !self.backend.is_loaded() {
            return Err(AdapterError::ModelNotLoaded(
                "No model loaded. Call load_model() first.".to_string(),
            ));
        }

        // Only text input is accepted; reject the other envelope kinds up front.
        let prompt = match &input.kind {
            EnvelopeKind::Text(prompt) => prompt,
            EnvelopeKind::Audio(_) => {
                return Err(AdapterError::InvalidInput(
                    "LLM adapter expects Text input, not Audio".to_string(),
                ));
            }
            EnvelopeKind::Embedding(_) => {
                return Err(AdapterError::InvalidInput(
                    "LLM adapter expects Text input, not Embedding".to_string(),
                ));
            }
        };

        let config = self.parse_generation_config(&input.metadata);
        let messages = match input.metadata.get("system_prompt") {
            Some(sys) => vec![ChatMessage::system(sys.as_str()), ChatMessage::user(prompt)],
            None => vec![ChatMessage::user(prompt)],
        };
        let output = self.backend.generate(&messages, &config)?;

        let token_count = output.tokens_generated.to_string();
        let elapsed_ms = output.generation_time_ms.to_string();
        let throughput = format!("{:.2}", output.tokens_per_second);

        let mut response_metadata = HashMap::new();

        // Optional streaming metrics are traced first, then the always-present ones.
        record_optional(
            &mut response_metadata,
            "ttft_ms",
            output.ttft_ms.map(|v| v.to_string()),
        );
        record_optional(
            &mut response_metadata,
            "mean_itl_ms",
            output.mean_itl_ms.map(|v| format!("{:.4}", v)),
        );
        record_optional(
            &mut response_metadata,
            "p95_itl_ms",
            output.p95_itl_ms.map(|v| v.to_string()),
        );
        record_optional(
            &mut response_metadata,
            "emitted_chunks",
            output.emitted_chunks.map(|v| v.to_string()),
        );
        record_optional(
            &mut response_metadata,
            "decode_tps",
            output.decode_tps.map(|v| format!("{:.4}", v)),
        );
        record_optional(
            &mut response_metadata,
            "prefill_tps",
            output.prefill_tps.map(|v| format!("{:.4}", v)),
        );

        crate::tracing::add_metadata("tokens_generated", &token_count);
        crate::tracing::add_metadata("tokens_out", &token_count);
        crate::tracing::add_metadata("generation_time_ms", &elapsed_ms);
        crate::tracing::add_metadata("tokens_per_second", &throughput);
        crate::tracing::add_metadata("finish_reason", &output.finish_reason);

        // NOTE(review): span_kind is fixed at compile time from features and
        // target OS, while execution_provider below is derived from the
        // backend name at run time, so the two can disagree (e.g. a CPU-only
        // backend in a macOS build with `llm-llamacpp`) — confirm intended.
        let span_kind = if cfg!(all(
            any(feature = "llm-mistral-metal", feature = "llm-llamacpp"),
            target_os = "macos"
        )) {
            "gpu"
        } else {
            "cpu"
        };
        crate::tracing::add_metadata("span_kind", span_kind);
        crate::tracing::add_metadata(
            "execution_provider",
            local_execution_provider(self.backend.name()),
        );

        // Cached-prefix length is only traced when the backend reports a positive value.
        match self.backend.last_cached_prefix_len() {
            Some(n) if n > 0 => {
                crate::tracing::add_metadata("prompt_cached_tokens", n.to_string());
            }
            _ => {}
        }

        response_metadata.insert("tokens_generated".to_string(), token_count);
        response_metadata.insert("generation_time_ms".to_string(), elapsed_ms);
        response_metadata.insert("tokens_per_second".to_string(), throughput);
        response_metadata.insert("finish_reason".to_string(), output.finish_reason);

        Ok(Envelope {
            kind: EnvelopeKind::Text(output.text),
            metadata: response_metadata,
        })
    }
}
impl RuntimeAdapterExt for LlmRuntimeAdapter {
    /// Whether the backend has a model loaded. The `_model_id` argument is
    /// ignored: the adapter holds at most one model.
    fn is_loaded(&self, _model_id: &str) -> bool {
        self.backend.is_loaded()
    }

    /// Metadata recorded by the last successful load; `_model_id` is ignored.
    ///
    /// # Errors
    /// `AdapterError::ModelNotLoaded` when no metadata is recorded.
    fn get_metadata(&self, _model_id: &str) -> AdapterResult<&ModelMetadata> {
        match &self.metadata {
            Some(metadata) => Ok(metadata),
            None => Err(AdapterError::ModelNotLoaded("No model loaded".to_string())),
        }
    }

    /// Runs inference by delegating to `execute`; `_model_id` is ignored.
    fn infer(&self, _model_id: &str, input: &Envelope) -> AdapterResult<Envelope> {
        self.execute(input)
    }

    /// Unloads the backend's model and clears the recorded metadata and path.
    ///
    /// # Errors
    /// Propagates the backend's unload error; in that case the recorded
    /// metadata and path are left untouched.
    fn unload_model(&mut self, _model_id: &str) -> AdapterResult<()> {
        self.backend.unload()?;
        self.current_model_path = None;
        self.metadata = None;
        Ok(())
    }

    /// Ids of loaded models: the recorded model id when the backend reports a
    /// loaded model and metadata is present, otherwise an empty list.
    fn list_loaded_models(&self) -> Vec<String> {
        if !self.backend.is_loaded() {
            return Vec::new();
        }
        match &self.metadata {
            Some(metadata) => vec![metadata.model_id.clone()],
            None => Vec::new(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `GenerationOutput` with the given headline numbers, a `"stop"`
    /// finish reason, and every optional metric unset.
    fn canned_output(
        text: &str,
        tokens_generated: usize,
        generation_time_ms: u64,
        tokens_per_second: f32,
    ) -> GenerationOutput {
        GenerationOutput {
            text: text.to_string(),
            tokens_generated,
            generation_time_ms,
            tokens_per_second,
            finish_reason: "stop".to_string(),
            ttft_ms: None,
            mean_itl_ms: None,
            p95_itl_ms: None,
            emitted_chunks: None,
            inter_chunk_ms: Vec::new(),
            decode_tps: None,
            prefill_tps: None,
        }
    }

    /// Backend stub that is always "loaded" and answers every generation
    /// request with a copy of `reply`. It relies on the trait's default
    /// streaming implementation.
    struct MockBackend {
        reply: GenerationOutput,
    }

    impl LlmBackend for MockBackend {
        fn name(&self) -> &str {
            "mock"
        }

        fn supported_formats(&self) -> Vec<&'static str> {
            vec!["test"]
        }

        fn load(&mut self, _config: &LlmConfig) -> LlmResult<()> {
            Ok(())
        }

        fn is_loaded(&self) -> bool {
            true
        }

        fn unload(&mut self) -> LlmResult<()> {
            Ok(())
        }

        fn generate(
            &self,
            _messages: &[ChatMessage],
            _config: &GenerationConfig,
        ) -> LlmResult<GenerationOutput> {
            Ok(self.reply.clone())
        }

        fn generate_raw(
            &self,
            prompt: &str,
            config: &GenerationConfig,
        ) -> LlmResult<GenerationOutput> {
            self.generate(&[ChatMessage::user(prompt)], config)
        }
    }

    /// Both known backend names resolve to a non-empty provider label.
    #[test]
    fn local_execution_provider_maps_known_backends() {
        let llama = local_execution_provider("llama-cpp");
        assert!(!llama.is_empty(), "llama-cpp must map to a label");
        let mistral = local_execution_provider("mistral");
        assert!(!mistral.is_empty(), "mistral must map to a label");
    }

    /// Unknown and empty backend names are reported as CPU.
    #[test]
    fn local_execution_provider_unknown_backend_falls_back_to_cpu() {
        for name in ["mock", ""] {
            assert_eq!(local_execution_provider(name), "cpu");
        }
    }

    /// A constructed `GenerationOutput` exposes the values it was built with.
    #[test]
    fn test_generation_output_structure() {
        let output = canned_output("Hello world", 3, 100, 30.0);
        assert_eq!(output.text, "Hello world");
        assert_eq!(output.tokens_generated, 3);
        assert_eq!(output.generation_time_ms, 100);
        assert!((output.tokens_per_second - 30.0).abs() < f32::EPSILON);
        assert_eq!(output.finish_reason, "stop");
    }

    /// The default streaming implementation delivers the full response as a
    /// single final token and returns the same output.
    #[test]
    fn test_default_streaming_implementation() {
        let backend = MockBackend {
            reply: canned_output("Test response", 2, 50, 40.0),
        };
        assert!(!backend.supports_streaming());

        let messages = vec![ChatMessage::user("test")];
        let config = GenerationConfig::default();
        let mut received_tokens: Vec<PartialToken> = Vec::new();
        let result = backend.generate_streaming(
            &messages,
            &config,
            Box::new(|token| {
                received_tokens.push(token);
                Ok(())
            }),
        );

        assert!(result.is_ok());
        let output = result.unwrap();
        assert_eq!(output.text, "Test response");

        assert_eq!(received_tokens.len(), 1);
        let only = &received_tokens[0];
        assert_eq!(only.token, "Test response");
        assert_eq!(only.cumulative_text, "Test response");
        assert_eq!(only.finish_reason, Some("stop".to_string()));
        assert!(only.is_final());
    }

    /// An error returned by the streaming callback surfaces as an adapter error.
    #[test]
    fn test_default_streaming_callback_error() {
        let backend = MockBackend {
            reply: canned_output("Test", 1, 10, 100.0),
        };
        let messages = vec![ChatMessage::user("test")];
        let config = GenerationConfig::default();

        let result = backend.generate_streaming(
            &messages,
            &config,
            Box::new(|_token| Err("User cancelled".into())),
        );

        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("callback error") || err_msg.contains("User cancelled"));
    }
}