use serde::{Deserialize, Serialize};
use swarm_engine_core::types::LoraConfig;
/// Supported LLM backends.
///
/// Serialized with `rename_all = "lowercase"`, so the canonical wire names are
/// `"ollama"`, `"openai"`, `"anthropic"`, `"vllm"`, `"mistral"`, `"llamacpp"`,
/// and `"llamacppserver"`; the `alias` attributes accept additional spellings
/// on deserialization only.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum LlmProvider {
/// Local Ollama server (the default backend).
#[default]
Ollama,
// NOTE(review): with rename_all = "lowercase" this variant already
// deserializes from "openai"; the alias looks redundant — confirm before removing.
#[serde(alias = "openai")]
OpenAI,
Anthropic,
#[serde(alias = "vllm")]
VLLM,
/// In-process mistral.rs backend (no HTTP endpoint; see `default_endpoint`).
#[serde(alias = "mistralrs")]
Mistral,
/// In-process llama.cpp backend (no HTTP endpoint).
#[serde(alias = "llama-cpp", alias = "llamacpp")]
LlamaCpp,
/// llama.cpp running as a separate HTTP server (`llama-server`).
#[serde(alias = "llama-server", alias = "llamaserver")]
LlamaCppServer,
}
impl LlmProvider {
    /// Well-known default endpoint for providers that are reached over HTTP.
    ///
    /// Returns `None` for backends that run in-process (`Mistral`, `LlamaCpp`)
    /// and therefore have no URL to connect to.
    pub fn default_endpoint(&self) -> Option<&'static str> {
        match self {
            LlmProvider::Ollama => Some("http://localhost:11434"),
            LlmProvider::OpenAI => Some("https://api.openai.com/v1"),
            LlmProvider::Anthropic => Some("https://api.anthropic.com"),
            LlmProvider::VLLM => Some("http://localhost:8000"),
            LlmProvider::LlamaCppServer => Some("http://localhost:8080"),
            LlmProvider::Mistral | LlmProvider::LlamaCpp => None,
        }
    }

    /// Whether this provider needs a network endpoint to operate.
    ///
    /// Only the in-process backends (`Mistral`, `LlamaCpp`) work without one.
    pub fn requires_endpoint(&self) -> bool {
        match self {
            LlmProvider::Mistral | LlmProvider::LlamaCpp => false,
            _ => true,
        }
    }
}
/// Complete configuration for one LLM backend instance.
///
/// All fields except `model` are optional in serialized form (TOML/JSON):
/// they either default via `#[serde(default)]` or a named default function.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmConfig {
/// Which backend to use; defaults to `LlmProvider::Ollama`.
#[serde(default)]
pub provider: LlmProvider,
/// Model identifier (backend-specific, e.g. an Ollama tag or HF repo id).
pub model: String,
/// Explicit endpoint URL; when `None`, `provider.default_endpoint()` is used
/// as a fallback (see `to_llm_decider_config`).
#[serde(default)]
pub endpoint: Option<String>,
/// Sampling temperature; defaults to 0.1.
#[serde(default = "default_temperature")]
pub temperature: f32,
#[serde(default)]
pub system_prompt: Option<String>,
#[serde(default)]
pub max_tokens: Option<usize>,
/// Request timeout in milliseconds; defaults to 30000.
#[serde(default = "default_timeout_ms")]
pub timeout_ms: u64,
/// Ollama context window; `to_ollama_config` falls back to 4096 when unset.
#[serde(default)]
pub num_ctx: Option<usize>,
/// Ollama prediction limit; `to_ollama_config` falls back to 256 when unset.
#[serde(default)]
pub num_predict: Option<usize>,
/// Local GGUF weight files; non-empty means GGUF loading (see `is_gguf`).
#[serde(default)]
pub gguf_files: Vec<String>,
#[serde(default = "default_true")]
pub paged_attention: bool,
/// In-place quantization scheme for HF loading (e.g. "q4k" in `mistral_hf`).
#[serde(default)]
pub quantization: Option<String>,
#[serde(default)]
pub lora: Option<LoraConfig>,
/// Chat template name; mapped by `to_chat_template` ("lfm2"/"qwen"/"llama3"...).
#[serde(default)]
pub chat_template: Option<String>,
}
/// Serde default for `LlmConfig::temperature`: low, near-deterministic sampling.
fn default_temperature() -> f32 {
    0.1
}
/// Serde default for `LlmConfig::timeout_ms`: 30 seconds.
fn default_timeout_ms() -> u64 {
    30_000
}
/// Serde default helper for boolean fields that should default to `true`
/// (serde's plain `#[serde(default)]` would yield `false`).
pub(crate) fn default_true() -> bool {
    true
}
impl Default for LlmConfig {
    /// Ollama backend with a small instruction-tuned GGUF model and the
    /// provider's standard local endpoint.
    fn default() -> Self {
        let provider = LlmProvider::default();
        let endpoint = provider.default_endpoint().map(str::to_owned);
        Self {
            provider,
            endpoint,
            model: "hf.co/LiquidAI/LFM2.5-1.2B-Instruct-GGUF:Q4_K_M".to_string(),
            temperature: default_temperature(),
            timeout_ms: default_timeout_ms(),
            system_prompt: None,
            max_tokens: None,
            num_ctx: None,
            num_predict: None,
            gguf_files: Vec::new(),
            paged_attention: true,
            quantization: None,
            lora: None,
            chat_template: None,
        }
    }
}
impl LlmConfig {
    /// mistral.rs backend loading local GGUF weight file(s).
    ///
    /// `model_id` is the HF repo id (or local directory) the files belong to;
    /// `files` are the GGUF filenames to load. Caps generation at 256 tokens.
    pub fn mistral_gguf(model_id: impl Into<String>, files: Vec<impl Into<String>>) -> Self {
        Self {
            provider: LlmProvider::Mistral,
            model: model_id.into(),
            // In-process backend: no HTTP endpoint (override the Ollama default).
            endpoint: None,
            max_tokens: Some(256),
            gguf_files: files.into_iter().map(|f| f.into()).collect(),
            ..Self::default()
        }
    }

    /// mistral.rs backend loading full HF weights with q4k in-place quantization.
    ///
    /// Caps generation at 256 tokens, like `mistral_gguf`.
    pub fn mistral_hf(model_id: impl Into<String>) -> Self {
        Self {
            provider: LlmProvider::Mistral,
            model: model_id.into(),
            // In-process backend: no HTTP endpoint (override the Ollama default).
            endpoint: None,
            max_tokens: Some(256),
            quantization: Some("q4k".to_string()),
            ..Self::default()
        }
    }

    /// Preset: LiquidAI LFM2.5 1.2B instruct, Q4_K_M GGUF.
    pub fn liquid_lfm_1b() -> Self {
        Self::mistral_gguf(
            "LiquidAI/LFM2.5-1.2B-Instruct-GGUF",
            vec!["LFM2.5-1.2B-Instruct-Q4_K_M.gguf"],
        )
    }

    /// Preset: Microsoft Phi-3.5 mini instruct from HF weights.
    pub fn phi3_mini() -> Self {
        Self::mistral_hf("microsoft/Phi-3.5-mini-instruct")
    }

    /// Preset: Qwen2.5 1.5B instruct from HF weights.
    pub fn qwen_1b() -> Self {
        Self::mistral_hf("Qwen/Qwen2.5-1.5B-Instruct")
    }

    /// Whether this config loads GGUF files (vs. full HF weights).
    pub fn is_gguf(&self) -> bool {
        !self.gguf_files.is_empty()
    }
}
impl LlmConfig {
    /// Convert into the generic decider config used by `swarm_engine_llm`.
    ///
    /// Endpoint resolution: explicit `endpoint` field first, then the
    /// provider's default, then an empty string for endpoint-less providers.
    pub fn to_llm_decider_config(
        &self,
        max_batch_size: usize,
    ) -> swarm_engine_llm::LlmDeciderConfig {
        let fallback = || self.provider.default_endpoint().unwrap_or("").to_string();
        let endpoint = self.endpoint.clone().unwrap_or_else(fallback);
        swarm_engine_llm::LlmDeciderConfig {
            model: self.model.clone(),
            endpoint,
            timeout_ms: self.timeout_ms,
            max_batch_size,
            temperature: self.temperature,
            system_prompt: self.system_prompt.clone(),
        }
    }

    /// Ollama-specific config: the decider config plus context/prediction
    /// limits (defaulting to 4096 ctx and 256 predicted tokens).
    pub fn to_ollama_config(&self, max_batch_size: usize) -> swarm_engine_llm::OllamaConfig {
        let base = self.to_llm_decider_config(max_batch_size);
        swarm_engine_llm::OllamaConfig {
            base,
            num_predict: self.num_predict.unwrap_or(256),
            num_ctx: self.num_ctx.unwrap_or(4096),
        }
    }

    /// Map the `chat_template` string to a concrete template.
    /// Unset and unrecognized values both fall back to `Lfm2`.
    pub fn to_chat_template(&self) -> swarm_engine_llm::ChatTemplate {
        match self.chat_template.as_deref() {
            Some("qwen") => swarm_engine_llm::ChatTemplate::Qwen,
            Some("llama3") | Some("llama") => swarm_engine_llm::ChatTemplate::Llama3,
            // "lfm2"/"lfm", None, and anything unknown all map to Lfm2.
            _ => swarm_engine_llm::ChatTemplate::Lfm2,
        }
    }
}
/// Partial `LlmConfig`: every field is optional so a deserialized override
/// only touches the fields it actually sets (see `apply_to`).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct LlmConfigOverride {
#[serde(default)]
pub provider: Option<LlmProvider>,
#[serde(default)]
pub model: Option<String>,
#[serde(default)]
pub endpoint: Option<String>,
#[serde(default)]
pub temperature: Option<f32>,
#[serde(default)]
pub system_prompt: Option<String>,
#[serde(default)]
pub max_tokens: Option<usize>,
#[serde(default)]
pub timeout_ms: Option<u64>,
#[serde(default)]
pub num_ctx: Option<usize>,
#[serde(default)]
pub num_predict: Option<usize>,
// Wrapped in Option (unlike the base Vec) so "unset" and "set to empty"
// remain distinguishable.
#[serde(default)]
pub gguf_files: Option<Vec<String>>,
#[serde(default)]
pub paged_attention: Option<bool>,
#[serde(default)]
pub quantization: Option<String>,
#[serde(default)]
pub lora: Option<LoraConfig>,
#[serde(default)]
pub chat_template: Option<String>,
}
// Copies `Some` fields from an override struct onto a base struct.
// Four arms, selected by how the base field is typed and whether the value
// is `Copy`:
//   copy:         base field is plain T, T: Copy  -> move the value out
//   clone:        base field is plain T, T: Clone -> clone via `ref` binding
//   option_copy:  base field is Option<T>, T: Copy  -> rewrap in Some
//   option_clone: base field is Option<T>, T: Clone -> clone and rewrap
// `None` override fields always leave the base field untouched.
macro_rules! apply_override {
($base:expr, $override:expr, copy: $($field:ident),* $(,)?) => {
$(
if let Some(val) = $override.$field {
$base.$field = val;
}
)*
};
($base:expr, $override:expr, clone: $($field:ident),* $(,)?) => {
$(
if let Some(ref val) = $override.$field {
$base.$field = val.clone();
}
)*
};
($base:expr, $override:expr, option_copy: $($field:ident),* $(,)?) => {
$(
if let Some(val) = $override.$field {
$base.$field = Some(val);
}
)*
};
($base:expr, $override:expr, option_clone: $($field:ident),* $(,)?) => {
$(
if let Some(ref val) = $override.$field {
$base.$field = Some(val.clone());
}
)*
};
}
impl LlmConfigOverride {
/// Applies every `Some` field of this override onto `base` in place;
/// `None` fields leave the corresponding base field unchanged.
/// The macro arm (copy/clone/option_*) must match each field's base type.
pub fn apply_to(&self, base: &mut LlmConfig) {
apply_override!(base, self, copy: provider, temperature, timeout_ms, paged_attention);
apply_override!(base, self, clone: model, gguf_files);
apply_override!(base, self, option_copy: max_tokens, num_ctx, num_predict);
apply_override!(base, self, option_clone: endpoint, system_prompt, quantization, lora, chat_template);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// HTTP providers expose a default endpoint; in-process backends do not.
    #[test]
    fn test_llm_provider_default_endpoint() {
        assert_eq!(
            LlmProvider::Ollama.default_endpoint(),
            Some("http://localhost:11434")
        );
        assert_eq!(LlmProvider::Mistral.default_endpoint(), None);
        assert_eq!(LlmProvider::LlamaCpp.default_endpoint(), None);
        assert_eq!(
            LlmProvider::LlamaCppServer.default_endpoint(),
            Some("http://localhost:8080")
        );
    }

    /// Only the in-process backends work without an endpoint.
    #[test]
    fn test_llm_provider_requires_endpoint() {
        assert!(LlmProvider::Ollama.requires_endpoint());
        assert!(LlmProvider::OpenAI.requires_endpoint());
        assert!(LlmProvider::LlamaCppServer.requires_endpoint());
        assert!(!LlmProvider::Mistral.requires_endpoint());
        assert!(!LlmProvider::LlamaCpp.requires_endpoint());
    }

    /// Every canonical name and every serde alias deserializes to the right
    /// variant (covers all variants, including `Anthropic`, which previously
    /// had no case here).
    #[test]
    fn test_llm_provider_deserialize_variants() {
        let test_cases = [
            (r#""ollama""#, LlmProvider::Ollama),
            (r#""openai""#, LlmProvider::OpenAI),
            (r#""anthropic""#, LlmProvider::Anthropic),
            (r#""vllm""#, LlmProvider::VLLM),
            (r#""mistral""#, LlmProvider::Mistral),
            (r#""mistralrs""#, LlmProvider::Mistral),
            (r#""llamacpp""#, LlmProvider::LlamaCpp),
            (r#""llama-cpp""#, LlmProvider::LlamaCpp),
            (r#""llamacppserver""#, LlmProvider::LlamaCppServer),
            (r#""llama-server""#, LlmProvider::LlamaCppServer),
        ];
        for (json, expected) in test_cases {
            let provider: LlmProvider = serde_json::from_str(json).unwrap();
            assert_eq!(provider, expected, "Failed for input: {}", json);
        }
    }

    /// The default config targets Ollama with its default endpoint filled in.
    #[test]
    fn test_llm_config_default() {
        let config = LlmConfig::default();
        assert_eq!(config.provider, LlmProvider::Ollama);
        assert!(config.endpoint.is_some());
        assert!((config.temperature - 0.1).abs() < 0.001);
    }

    /// GGUF constructor: Mistral provider, no endpoint, files recorded.
    #[test]
    fn test_llm_config_mistral_gguf() {
        let config =
            LlmConfig::mistral_gguf("LiquidAI/LFM2.5-1.2B-Instruct-GGUF", vec!["test.gguf"]);
        assert_eq!(config.provider, LlmProvider::Mistral);
        assert!(config.endpoint.is_none());
        assert_eq!(config.gguf_files, vec!["test.gguf"]);
        assert!(config.is_gguf());
    }

    /// HF constructor: no GGUF files, q4k quantization applied.
    #[test]
    fn test_llm_config_mistral_hf() {
        let config = LlmConfig::mistral_hf("microsoft/Phi-3.5-mini-instruct");
        assert_eq!(config.provider, LlmProvider::Mistral);
        assert!(config.gguf_files.is_empty());
        assert_eq!(config.quantization, Some("q4k".to_string()));
        assert!(!config.is_gguf());
    }

    /// Presets pick the expected loading path (GGUF vs quantized HF).
    #[test]
    fn test_llm_config_presets() {
        let liquid = LlmConfig::liquid_lfm_1b();
        assert!(liquid.is_gguf());
        let phi = LlmConfig::phi3_mini();
        assert!(!phi.is_gguf());
        assert!(phi.quantization.is_some());
        let qwen = LlmConfig::qwen_1b();
        assert!(!qwen.is_gguf());
    }

    /// Ollama conversion threads batch size through and defaults num_ctx.
    #[test]
    fn test_llm_config_to_ollama_config() {
        let config = LlmConfig::default();
        let ollama_config = config.to_ollama_config(4);
        assert_eq!(ollama_config.base.max_batch_size, 4);
        assert_eq!(ollama_config.num_ctx, 4096);
    }

    #[test]
    fn test_llm_config_to_llm_decider_config() {
        let config = LlmConfig::default();
        let decider_config = config.to_llm_decider_config(8);
        assert_eq!(decider_config.max_batch_size, 8);
        assert!((decider_config.temperature - 0.1).abs() < 0.001);
    }

    /// Override applies only its `Some` fields; untouched fields keep defaults.
    #[test]
    fn test_llm_config_override_apply() {
        let mut base = LlmConfig::default();
        let override_config = LlmConfigOverride {
            model: Some("new-model".to_string()),
            temperature: Some(0.5),
            max_tokens: Some(1024),
            ..Default::default()
        };
        override_config.apply_to(&mut base);
        assert_eq!(base.model, "new-model");
        assert!((base.temperature - 0.5).abs() < 0.001);
        assert_eq!(base.max_tokens, Some(1024));
        assert_eq!(base.provider, LlmProvider::Ollama);
    }

    /// Minimal TOML (only `model` is required) deserializes cleanly.
    #[test]
    fn test_llm_config_deserialize_toml() {
        let toml_str = r#"
provider = "ollama"
model = "qwen2.5-coder:7b"
temperature = 0.2
"#;
        let config: LlmConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.provider, LlmProvider::Ollama);
        assert_eq!(config.model, "qwen2.5-coder:7b");
        assert!((config.temperature - 0.2).abs() < 0.001);
    }

    /// Extended GGUF/quantization fields round-trip from TOML.
    #[test]
    fn test_llm_config_extended_fields_deserialize_toml() {
        let toml_str = r#"
provider = "mistral"
model = "LiquidAI/LFM2.5-1.2B-Instruct-GGUF"
gguf_files = ["LFM2.5-1.2B-Instruct-Q4_K_M.gguf"]
paged_attention = false
quantization = "q8_0"
"#;
        let config: LlmConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.provider, LlmProvider::Mistral);
        assert_eq!(config.gguf_files, vec!["LFM2.5-1.2B-Instruct-Q4_K_M.gguf"]);
        assert!(!config.paged_attention);
        assert_eq!(config.quantization, Some("q8_0".to_string()));
    }

    /// Template mapping: named templates resolve; unset/unknown fall back to Lfm2.
    #[test]
    fn test_to_chat_template() {
        use swarm_engine_llm::ChatTemplate;
        let mut config = LlmConfig::default();
        assert!(matches!(config.to_chat_template(), ChatTemplate::Lfm2));
        config.chat_template = Some("lfm2".to_string());
        assert!(matches!(config.to_chat_template(), ChatTemplate::Lfm2));
        config.chat_template = Some("qwen".to_string());
        assert!(matches!(config.to_chat_template(), ChatTemplate::Qwen));
        config.chat_template = Some("llama3".to_string());
        assert!(matches!(config.to_chat_template(), ChatTemplate::Llama3));
        config.chat_template = Some("unknown".to_string());
        assert!(matches!(config.to_chat_template(), ChatTemplate::Lfm2));
    }

    #[test]
    fn test_chat_template_deserialize_toml() {
        let toml_str = r#"
provider = "llama-server"
model = "qwen2.5"
chat_template = "qwen"
"#;
        let config: LlmConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.chat_template, Some("qwen".to_string()));
        use swarm_engine_llm::ChatTemplate;
        assert!(matches!(config.to_chat_template(), ChatTemplate::Qwen));
    }
}