use std::time::Duration;
use nng::options::{Options, RecvTimeout, SendTimeout};
use nng::{Protocol, Socket};
use super::protocol::*;
use crate::error::MullamaError;
/// Client handle that talks to the daemon over an nng request (`Req0`) socket.
///
/// Instances are created through `connect_default`, `connect`, or
/// `connect_with_timeout`.
pub struct DaemonClient {
/// Request socket; created and dialed in `connect_with_timeout`.
socket: Socket,
/// Timeout given at connect time; `request` uses it as the receive timeout.
timeout: Duration,
}
impl DaemonClient {
/// Connects to the daemon at the module-level default socket address.
///
/// Shorthand for [`DaemonClient::connect`] with `super::DEFAULT_SOCKET`.
///
/// # Errors
/// Returns whatever error [`DaemonClient::connect`] reports.
pub fn connect_default() -> Result<Self, MullamaError> {
    let default_addr: &str = super::DEFAULT_SOCKET;
    Self::connect(default_addr)
}
/// Connects to the daemon at `addr` using a five-second timeout.
///
/// # Errors
/// Returns whatever error [`DaemonClient::connect_with_timeout`] reports.
pub fn connect(addr: &str) -> Result<Self, MullamaError> {
    let default_timeout = Duration::from_secs(5);
    Self::connect_with_timeout(addr, default_timeout)
}
/// Creates a `Req0` socket, applies `timeout` to both receive and send, and
/// dials `addr`.
///
/// The same `timeout` is stored on the client and later used by `request`.
///
/// # Errors
/// Returns `MullamaError::DaemonError` if the socket cannot be created, if
/// either timeout option cannot be set, or if dialing `addr` fails.
pub fn connect_with_timeout(addr: &str, timeout: Duration) -> Result<Self, MullamaError> {
    // Formats a failure as "<context>: <cause>" inside a daemon error.
    fn daemon_err(context: &str, cause: impl std::fmt::Display) -> MullamaError {
        MullamaError::DaemonError(format!("{}: {}", context, cause))
    }

    let socket =
        Socket::new(Protocol::Req0).map_err(|e| daemon_err("Failed to create socket", e))?;

    // Receive timeout is configured before the send timeout; the order only
    // matters for which error surfaces first.
    socket
        .set_opt::<RecvTimeout>(Some(timeout))
        .map_err(|e| daemon_err("Failed to set recv timeout", e))?;
    socket
        .set_opt::<SendTimeout>(Some(timeout))
        .map_err(|e| daemon_err("Failed to set send timeout", e))?;

    if let Err(e) = socket.dial(addr) {
        return Err(daemon_err(&format!("Failed to connect to {}", addr), e));
    }

    Ok(Self { socket, timeout })
}
/// Sends `request` and waits for the reply using the client's stored timeout.
///
/// # Errors
/// Returns whatever error [`DaemonClient::request_with_timeout`] reports.
pub fn request(&self, request: &Request) -> Result<Response, MullamaError> {
    let default_timeout = self.timeout;
    self.request_with_timeout(request, default_timeout)
}
pub fn request_with_timeout(
&self,
request: &Request,
timeout: Duration,
) -> Result<Response, MullamaError> {
self.socket
.set_opt::<RecvTimeout>(Some(timeout))
.map_err(|e| MullamaError::DaemonError(format!("Failed to set timeout: {}", e)))?;
let req_bytes = request
.to_bytes()
.map_err(|e| MullamaError::DaemonError(format!("Serialization failed: {}", e)))?;
self.socket
.send(nng::Message::from(req_bytes.as_slice()))
.map_err(|(_, e)| MullamaError::DaemonError(format!("Send failed: {}", e)))?;
let msg = self.socket.recv().map_err(|e| {
if e == nng::Error::TimedOut {
MullamaError::DaemonError("Request timed out - daemon may have crashed".to_string())
} else {
MullamaError::DaemonError(format!("Receive failed: {}", e))
}
})?;
Response::from_bytes(&msg)
.map_err(|e| MullamaError::DaemonError(format!("Deserialization failed: {}", e)))
}
pub fn ping(&self) -> Result<(u64, String), MullamaError> {
match self.request(&Request::Ping)? {
Response::Pong {
uptime_secs,
version,
} => Ok((uptime_secs, version)),
Response::Error { message, .. } => Err(MullamaError::DaemonError(message)),
_ => Err(MullamaError::DaemonError("Unexpected response".into())),
}
}
/// Sends a `Status` request and returns the daemon's status payload.
///
/// # Errors
/// Returns `MullamaError::DaemonError` if the request fails, if the daemon
/// replies with an error, or if the reply is of any other kind.
pub fn status(&self) -> Result<DaemonStatus, MullamaError> {
    let reply = self.request(&Request::Status)?;
    match reply {
        Response::Status(daemon_status) => Ok(daemon_status),
        Response::Error { message, .. } => Err(MullamaError::DaemonError(message)),
        _ => Err(MullamaError::DaemonError("Unexpected response".into())),
    }
}
pub fn list_models(&self) -> Result<Vec<ModelStatus>, MullamaError> {
match self.request(&Request::ListModels)? {
Response::Models(models) => Ok(models),
Response::Error { message, .. } => Err(MullamaError::DaemonError(message)),
_ => Err(MullamaError::DaemonError("Unexpected response".into())),
}
}
/// Asks the daemon to load a model described by `spec`.
///
/// `spec` is either `alias:path` (split at the first `:`) or a bare path. For
/// a bare path the alias is the path's file stem, falling back to `"model"`
/// when there is none.
///
/// A spec that parses as a rooted path with a platform path prefix (for
/// example `C:\models\llama.gguf` on Windows) is treated as a bare path, so the
/// drive-letter colon is not mistaken for the alias separator. Such prefixes
/// never occur on Unix, so parsing there is unaffected.
///
/// All other load parameters are sent with fixed values: `gpu_layers` and
/// `context_size` are `0`, the boolean flags are `false`, and every optional
/// field is `None`.
///
/// # Errors
/// Returns `MullamaError::DaemonError` if the request fails, if the daemon
/// replies with an error, or if the reply is of any other kind.
pub fn load_model(&self, spec: &str) -> Result<(String, ModelInfo), MullamaError> {
    let spec_path = std::path::Path::new(spec);

    // Only rooted paths that start with a platform prefix are exempt from
    // alias splitting; `m:rel/path` style specs keep their original meaning.
    let is_prefixed_absolute_path = spec_path.has_root()
        && matches!(
            spec_path.components().next(),
            Some(std::path::Component::Prefix(_))
        );

    let explicit_alias = if is_prefixed_absolute_path {
        None
    } else {
        spec.split_once(':')
    };

    let (alias, path) = match explicit_alias {
        Some((alias, path)) => (alias.to_string(), path.to_string()),
        None => {
            let alias = spec_path
                .file_stem()
                .map(|stem| stem.to_string_lossy().into_owned())
                .unwrap_or_else(|| "model".to_string());
            (alias, spec.to_string())
        }
    };

    let params = ModelLoadParams {
        alias,
        path,
        gpu_layers: 0,
        context_size: 0,
        use_mmap: None,
        use_mlock: false,
        flash_attn: false,
        cache_type_k: None,
        cache_type_v: None,
        rope_freq_base: None,
        rope_freq_scale: None,
        n_batch: None,
        defrag_thold: None,
        split_mode: None,
    };

    match self.request(&Request::LoadModel(params))? {
        Response::ModelLoaded { alias, info } => Ok((alias, info)),
        Response::Error { message, .. } => Err(MullamaError::DaemonError(message)),
        _ => Err(MullamaError::DaemonError("Unexpected response".into())),
    }
}
pub fn load_model_with_options(
&self,
alias: &str,
path: &str,
gpu_layers: i32,
context_size: u32,
) -> Result<(String, ModelInfo), MullamaError> {
match self.request(&Request::LoadModel(ModelLoadParams {
alias: alias.to_string(),
path: path.to_string(),
gpu_layers,
context_size,
use_mmap: None,
use_mlock: false,
flash_attn: false,
cache_type_k: None,
cache_type_v: None,
rope_freq_base: None,
rope_freq_scale: None,
n_batch: None,
defrag_thold: None,
split_mode: None,
}))? {
Response::ModelLoaded { alias, info } => Ok((alias, info)),
Response::Error { message, .. } => Err(MullamaError::DaemonError(message)),
_ => Err(MullamaError::DaemonError("Unexpected response".into())),
}
}
/// Asks the daemon to load the model at `path` under `alias`, forwarding the
/// caller's GPU, context, attention, cache, memory-mapping, locking and batch
/// settings.
///
/// `rope_freq_base`, `rope_freq_scale`, `defrag_thold` and `split_mode` are
/// always sent as `None`.
///
/// # Errors
/// Returns `MullamaError::DaemonError` if the request fails, if the daemon
/// replies with an error, or if the reply is of any other kind.
pub fn load_model_full(
    &self,
    alias: &str,
    path: &str,
    gpu_layers: i32,
    context_size: u32,
    flash_attn: bool,
    cache_type_k: Option<String>,
    cache_type_v: Option<String>,
    use_mmap: Option<bool>,
    use_mlock: bool,
    n_batch: Option<u32>,
) -> Result<(String, ModelInfo), MullamaError> {
    let params = ModelLoadParams {
        alias: alias.to_owned(),
        path: path.to_owned(),
        gpu_layers,
        context_size,
        use_mmap,
        use_mlock,
        flash_attn,
        cache_type_k,
        cache_type_v,
        rope_freq_base: None,
        rope_freq_scale: None,
        n_batch,
        defrag_thold: None,
        split_mode: None,
    };

    let reply = self.request(&Request::LoadModel(params))?;
    match reply {
        Response::ModelLoaded { alias, info } => Ok((alias, info)),
        Response::Error { message, .. } => Err(MullamaError::DaemonError(message)),
        _ => Err(MullamaError::DaemonError("Unexpected response".into())),
    }
}
pub fn unload_model(&self, alias: &str) -> Result<(), MullamaError> {
match self.request(&Request::UnloadModel {
alias: alias.to_string(),
})? {
Response::ModelUnloaded { .. } => Ok(()),
Response::Error { message, .. } => Err(MullamaError::DaemonError(message)),
_ => Err(MullamaError::DaemonError("Unexpected response".into())),
}
}
/// Asks the daemon to make the model registered under `alias` the default.
///
/// # Errors
/// Returns `MullamaError::DaemonError` if the request fails, if the daemon
/// replies with an error, or if the reply is of any other kind.
pub fn set_default_model(&self, alias: &str) -> Result<(), MullamaError> {
    let set_default = Request::SetDefaultModel {
        alias: alias.to_owned(),
    };
    let reply = self.request(&set_default)?;
    match reply {
        Response::DefaultModelSet { .. } => Ok(()),
        Response::Error { message, .. } => Err(MullamaError::DaemonError(message)),
        _ => Err(MullamaError::DaemonError("Unexpected response".into())),
    }
}
/// Sends a single-turn chat: wraps `message` as one `"user"` message and
/// forwards it to [`DaemonClient::chat_completion`].
///
/// # Errors
/// Returns whatever error [`DaemonClient::chat_completion`] reports.
pub fn chat(
    &self,
    message: &str,
    model: Option<&str>,
    max_tokens: u32,
    temperature: f32,
) -> Result<ChatResult, MullamaError> {
    let user_turn = ChatMessage {
        role: "user".to_string(),
        content: message.to_string().into(),
        name: None,
        tool_calls: None,
        tool_call_id: None,
    };
    self.chat_completion(vec![user_turn], model, max_tokens, temperature)
}
/// Runs a non-streaming chat completion over `messages` and summarizes the
/// first choice.
///
/// The request waits up to 300 seconds for a reply. The returned text is the
/// first choice's message content, or an empty string if there are no choices.
/// `duration_ms` is wall-clock time measured from the start of this call.
///
/// # Errors
/// Returns `MullamaError::DaemonError` if the request fails, if the daemon
/// replies with an error, or if the reply is of any other kind.
pub fn chat_completion(
    &self,
    messages: Vec<ChatMessage>,
    model: Option<&str>,
    max_tokens: u32,
    temperature: f32,
) -> Result<ChatResult, MullamaError> {
    let started = std::time::Instant::now();
    let generation_timeout = Duration::from_secs(300);

    let params = ChatCompletionParams {
        model: model.map(String::from),
        messages,
        max_tokens,
        temperature: Some(temperature),
        top_p: None,
        top_k: None,
        frequency_penalty: None,
        presence_penalty: None,
        stream: false,
        stop: vec![],
        response_format: None,
        tools: None,
        tool_choice: None,
        thinking: None,
    };

    let reply =
        self.request_with_timeout(&Request::ChatCompletion(params), generation_timeout)?;
    match reply {
        Response::ChatCompletion(resp) => {
            let text = match resp.choices.first() {
                Some(choice) => choice.message.content.text(),
                None => String::new(),
            };
            Ok(ChatResult {
                text,
                model: resp.model,
                prompt_tokens: resp.usage.prompt_tokens,
                completion_tokens: resp.usage.completion_tokens,
                duration_ms: started.elapsed().as_millis() as u64,
            })
        }
        Response::Error { message, .. } => Err(MullamaError::DaemonError(message)),
        _ => Err(MullamaError::DaemonError("Unexpected response".into())),
    }
}
/// Runs a non-streaming text completion for `prompt` and summarizes the first
/// choice.
///
/// The request waits up to 300 seconds for a reply. The returned text is the
/// first choice's text, or an empty string if there are no choices.
/// `duration_ms` is wall-clock time measured from the start of this call.
///
/// # Errors
/// Returns `MullamaError::DaemonError` if the request fails, if the daemon
/// replies with an error, or if the reply is of any other kind.
pub fn complete(
    &self,
    prompt: &str,
    model: Option<&str>,
    max_tokens: u32,
    temperature: f32,
) -> Result<CompletionResult, MullamaError> {
    let started = std::time::Instant::now();
    let generation_timeout = Duration::from_secs(300);

    let params = CompletionParams {
        model: model.map(String::from),
        prompt: prompt.to_string(),
        max_tokens,
        temperature: Some(temperature),
        top_p: None,
        top_k: None,
        frequency_penalty: None,
        presence_penalty: None,
        stream: false,
        stop: vec![],
    };

    let reply = self.request_with_timeout(&Request::Completion(params), generation_timeout)?;
    match reply {
        Response::Completion(resp) => {
            let text = match resp.choices.first() {
                Some(choice) => choice.text.clone(),
                None => String::new(),
            };
            Ok(CompletionResult {
                text,
                model: resp.model,
                prompt_tokens: resp.usage.prompt_tokens,
                completion_tokens: resp.usage.completion_tokens,
                duration_ms: started.elapsed().as_millis() as u64,
            })
        }
        Response::Error { message, .. } => Err(MullamaError::DaemonError(message)),
        _ => Err(MullamaError::DaemonError("Unexpected response".into())),
    }
}
pub fn tokenize(&self, text: &str, model: Option<&str>) -> Result<Vec<i32>, MullamaError> {
match self.request(&Request::Tokenize {
model: model.map(String::from),
text: text.to_string(),
})? {
Response::Tokens { tokens, .. } => Ok(tokens),
Response::Error { message, .. } => Err(MullamaError::DaemonError(message)),
_ => Err(MullamaError::DaemonError("Unexpected response".into())),
}
}
/// Requests an embedding for a single `text`, optionally naming a `model`.
///
/// The request waits up to 60 seconds for a reply. The result carries the
/// first embedding in the reply, or an empty vector if the reply has none.
///
/// # Errors
/// Returns `MullamaError::DaemonError` if the request fails, if the daemon
/// replies with an error, or if the reply is of any other kind.
pub fn embed(&self, text: &str, model: Option<&str>) -> Result<EmbeddingResult, MullamaError> {
    let embedding_timeout = Duration::from_secs(60);
    let embed_request = Request::Embeddings {
        model: model.map(String::from),
        input: EmbeddingInput::Single(text.to_string()),
    };

    let reply = self.request_with_timeout(&embed_request, embedding_timeout)?;
    match reply {
        Response::Embeddings(resp) => {
            // Only the first entry is used; any further entries are dropped.
            let embedding = match resp.data.into_iter().next() {
                Some(first) => first.embedding,
                None => Vec::new(),
            };
            Ok(EmbeddingResult {
                embedding,
                model: resp.model,
                prompt_tokens: resp.usage.prompt_tokens,
            })
        }
        Response::Error { message, .. } => Err(MullamaError::DaemonError(message)),
        _ => Err(MullamaError::DaemonError("Unexpected response".into())),
    }
}
/// Requests embeddings for every entry of `texts` in one call, optionally
/// naming a `model`.
///
/// The request waits up to 300 seconds for a reply. The result holds one
/// embedding per entry of the reply's data, in the order the reply lists them.
///
/// # Errors
/// Returns `MullamaError::DaemonError` if the request fails, if the daemon
/// replies with an error, or if the reply is of any other kind.
pub fn embed_batch(
    &self,
    texts: &[&str],
    model: Option<&str>,
) -> Result<BatchEmbeddingResult, MullamaError> {
    let embedding_timeout = Duration::from_secs(300);
    let batch_request = Request::Embeddings {
        model: model.map(String::from),
        input: EmbeddingInput::Multiple(texts.iter().map(|&s| s.to_owned()).collect()),
    };

    let reply = self.request_with_timeout(&batch_request, embedding_timeout)?;
    match reply {
        Response::Embeddings(resp) => {
            let embeddings: Vec<Vec<f32>> =
                resp.data.into_iter().map(|item| item.embedding).collect();
            Ok(BatchEmbeddingResult {
                embeddings,
                model: resp.model,
                prompt_tokens: resp.usage.prompt_tokens,
            })
        }
        Response::Error { message, .. } => Err(MullamaError::DaemonError(message)),
        _ => Err(MullamaError::DaemonError("Unexpected response".into())),
    }
}
/// Sends a `Shutdown` request and succeeds once the daemon acknowledges with
/// `ShuttingDown`.
///
/// # Errors
/// Returns `MullamaError::DaemonError` if the request fails, if the daemon
/// replies with an error, or if the reply is of any other kind.
pub fn shutdown(&self) -> Result<(), MullamaError> {
    let reply = self.request(&Request::Shutdown)?;
    match reply {
        Response::ShuttingDown => Ok(()),
        Response::Error { message, .. } => Err(MullamaError::DaemonError(message)),
        _ => Err(MullamaError::DaemonError("Unexpected response".into())),
    }
}
}
/// Summary of a chat completion as returned by `DaemonClient::chat_completion`.
#[derive(Debug, Clone)]
pub struct ChatResult {
    /// Text of the first choice (empty when the reply had no choices).
    pub text: String,
    /// Model name reported in the reply.
    pub model: String,
    /// Prompt token count reported in the reply's usage.
    pub prompt_tokens: u32,
    /// Completion token count reported in the reply's usage.
    pub completion_tokens: u32,
    /// Elapsed wall-clock time of the call, in milliseconds.
    pub duration_ms: u64,
}

impl ChatResult {
    /// Completion tokens divided by the elapsed time in seconds.
    ///
    /// Returns `0.0` when `duration_ms` is zero, avoiding a division by zero.
    pub fn tokens_per_second(&self) -> f64 {
        match self.duration_ms {
            0 => 0.0,
            millis => {
                let seconds = millis as f64 / 1000.0;
                f64::from(self.completion_tokens) / seconds
            }
        }
    }
}
/// Summary of a text completion as returned by `DaemonClient::complete`.
#[derive(Debug, Clone)]
pub struct CompletionResult {
    /// Text of the first choice (empty when the reply had no choices).
    pub text: String,
    /// Model name reported in the reply.
    pub model: String,
    /// Prompt token count reported in the reply's usage.
    pub prompt_tokens: u32,
    /// Completion token count reported in the reply's usage.
    pub completion_tokens: u32,
    /// Elapsed wall-clock time of the call, in milliseconds.
    pub duration_ms: u64,
}

impl CompletionResult {
    /// Completion tokens divided by the elapsed time in seconds.
    ///
    /// Returns `0.0` when `duration_ms` is zero, avoiding a division by zero.
    pub fn tokens_per_second(&self) -> f64 {
        if self.duration_ms == 0 {
            return 0.0;
        }
        let elapsed_secs = self.duration_ms as f64 / 1000.0;
        let generated = self.completion_tokens as f64;
        generated / elapsed_secs
    }
}
/// Result of a single-text embedding request made through `DaemonClient::embed`.
#[derive(Debug, Clone)]
pub struct EmbeddingResult {
    /// The embedding vector (empty when the reply contained no data).
    pub embedding: Vec<f32>,
    /// Model name reported in the reply.
    pub model: String,
    /// Prompt token count reported in the reply's usage.
    pub prompt_tokens: u32,
}

impl EmbeddingResult {
    /// Number of components in the embedding vector.
    pub fn dimension(&self) -> usize {
        let Self { embedding, .. } = self;
        embedding.len()
    }
}
/// Result of a multi-text embedding request made through
/// `DaemonClient::embed_batch`.
#[derive(Debug, Clone)]
pub struct BatchEmbeddingResult {
    /// One embedding vector per entry of the reply's data.
    pub embeddings: Vec<Vec<f32>>,
    /// Model name reported in the reply.
    pub model: String,
    /// Prompt token count reported in the reply's usage.
    pub prompt_tokens: u32,
}

impl BatchEmbeddingResult {
    /// Number of embedding vectors held.
    pub fn count(&self) -> usize {
        let Self { embeddings, .. } = self;
        embeddings.len()
    }

    /// Length of the first embedding vector, or `0` when there are none.
    ///
    /// Only the first vector is inspected; the others are not checked for a
    /// matching length.
    pub fn dimension(&self) -> usize {
        self.embeddings.first().map_or(0, Vec::len)
    }
}