autoagents_onnx/
chat.rs

1//! autoagents-onnx backend implementation for local model inference.
2//!
3//! This backend uses the autoagents-onnx inference runtime to run LLM models locally.
4//! It handles tokenization, text generation, and sampling specifically for LLMs.
5
6use crate::{Device, EdgeError, InferenceInput, InferenceRuntime, Model};
7use async_trait::async_trait;
8use autoagents_llm::chat::Tool;
9use autoagents_llm::models::{ModelListRequest, ModelListResponse};
10use autoagents_llm::{
11    chat::{
12        ChatMessage, ChatProvider, ChatResponse, ChatRole, MessageType, StructuredOutputFormat,
13    },
14    completion::{CompletionProvider, CompletionRequest, CompletionResponse},
15    embedding::EmbeddingProvider,
16    error::LLMError,
17    models::ModelsProvider,
18    LLMProvider, ToolCall,
19};
20use minijinja::{context, Environment};
21use serde::{Deserialize, Serialize};
22use serde_json::Value;
23use std::{path::Path, sync::Arc};
24use tokenizers::Tokenizer;
25
/// autoagents-onnx backend for local LLM inference
pub struct OnnxEdge {
    // Behind an async Mutex: `infer` requires `&mut` access while the
    // provider traits only hand out `&self`. Generation serializes here.
    inference_runtime: tokio::sync::Mutex<InferenceRuntime>,
    // HuggingFace tokenizer loaded from `tokenizer.json` in the model dir.
    tokenizer: Tokenizer,
    // Values parsed from the model's `config.json` (vocab size, token ids, ...).
    model_config: ModelConfig,
    // Maximum number of NEW tokens generated per request (default 50).
    max_tokens: u32,
    // Softmax temperature used by top-p sampling (default 0.7).
    temperature: f32,
    // Nucleus (top-p) sampling threshold (default 0.9).
    top_p: f32,
    // Optional system prompt prepended when the caller supplies none.
    system: Option<String>,
    // Jinja2 chat template, if one was found (file, tokenizer.json or config.json).
    chat_template: Option<String>,
}
37
/// Model configuration for LLM inference
///
/// Populated from the model's `config.json`; missing numeric fields fall
/// back to defaults during construction (32000 vocab, 2048 positions).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    // Size of the token vocabulary; also used to slice the logits tensor.
    pub vocab_size: u32,
    // Maximum context length the model supports.
    pub max_position_embeddings: u32,
    // Beginning-of-sequence token id, when declared by the config.
    pub bos_token_id: Option<u32>,
    // End-of-sequence token id; generation stops when it is produced.
    pub eos_token_id: Option<u32>,
    // Padding token id, when declared by the config.
    pub pad_token_id: Option<u32>,
}
47
/// Generation configuration for text generation
#[derive(Debug, Clone)]
pub struct GenerationConfig {
    // Upper bound on tokens generated beyond the prompt.
    pub max_new_tokens: u32,
    // Softmax temperature; lower values sharpen the distribution.
    pub temperature: f32,
    // Nucleus sampling threshold in (0, 1].
    pub top_p: f32,
    // When false, greedy (argmax) decoding is used instead of sampling.
    pub do_sample: bool,
}
56
57impl Default for GenerationConfig {
58    fn default() -> Self {
59        Self {
60            max_new_tokens: 50,
61            temperature: 0.7,
62            top_p: 0.9,
63            do_sample: true,
64        }
65    }
66}
67
/// Response wrapper for chat responses produced by the ONNX edge backend
#[derive(Debug)]
pub struct EdgeResponse {
    // Raw generated text, exposed through `ChatResponse::text` and `Display`.
    text: String,
}
73
74impl ChatResponse for EdgeResponse {
75    fn text(&self) -> Option<String> {
76        Some(self.text.clone())
77    }
78
79    fn tool_calls(&self) -> Option<Vec<ToolCall>> {
80        None // Tool calls not supported yet
81    }
82}
83
84impl std::fmt::Display for EdgeResponse {
85    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86        write!(f, "{}", self.text)
87    }
88}
89
90impl OnnxEdge {
91    /// Create a new LiquidEdge instance from a model with a specific device
92    pub async fn from_model_with_device(
93        model: Box<dyn Model>,
94        device: Device,
95        _model_name: String,
96        max_tokens: Option<u32>,
97        temperature: Option<f32>,
98        top_p: Option<f32>,
99        system: Option<String>,
100    ) -> Result<Self, LLMError> {
101        let model_path = model.model_path().to_path_buf();
102
103        // Load inference runtime with the model and device
104        let inference_runtime = InferenceRuntime::from_model_with_device(model, device)
105            .await
106            .map_err(|e| LLMError::ProviderError(format!("Failed to load model: {e}")))?;
107
108        Self::from_runtime(
109            inference_runtime,
110            model_path,
111            max_tokens,
112            temperature,
113            top_p,
114            system,
115        )
116        .await
117    }
118
119    /// Create a new LiquidEdge instance from a model (uses CPU device by default)
120    pub async fn from_model(
121        model: Box<dyn Model>,
122        _model_name: String,
123        max_tokens: Option<u32>,
124        temperature: Option<f32>,
125        top_p: Option<f32>,
126        system: Option<String>,
127    ) -> Result<Self, LLMError> {
128        let model_path = model.model_path().to_path_buf();
129
130        // Load inference runtime with the model
131        let inference_runtime = InferenceRuntime::from_model(model)
132            .await
133            .map_err(|e| LLMError::ProviderError(format!("Failed to load model: {e}")))?;
134
135        Self::from_runtime(
136            inference_runtime,
137            model_path,
138            max_tokens,
139            temperature,
140            top_p,
141            system,
142        )
143        .await
144    }
145
146    /// Common initialization logic for LiquidEdge
147    async fn from_runtime(
148        inference_runtime: InferenceRuntime,
149        model_path: std::path::PathBuf,
150        max_tokens: Option<u32>,
151        temperature: Option<f32>,
152        top_p: Option<f32>,
153        system: Option<String>,
154    ) -> Result<Self, LLMError> {
155        // Load tokenizer
156        let tokenizer_path = model_path.join("tokenizer.json");
157        let tokenizer = Tokenizer::from_file(&tokenizer_path)
158            .map_err(|e| LLMError::ProviderError(format!("Failed to load tokenizer: {e}")))?;
159
160        // Load model config
161        let config_path = model_path.join("config.json");
162        let config_content = std::fs::read_to_string(&config_path)
163            .map_err(|e| LLMError::ProviderError(format!("Failed to read config.json: {e}")))?;
164
165        let config_json: Value = serde_json::from_str(&config_content)
166            .map_err(|e| LLMError::ProviderError(format!("Failed to parse config.json: {e}")))?;
167
168        let model_config = ModelConfig {
169            vocab_size: config_json
170                .get("vocab_size")
171                .and_then(|v| v.as_u64())
172                .unwrap_or(32000) as u32,
173            max_position_embeddings: config_json
174                .get("max_position_embeddings")
175                .and_then(|v| v.as_u64())
176                .unwrap_or(2048) as u32,
177            bos_token_id: config_json
178                .get("bos_token_id")
179                .and_then(|v| v.as_u64())
180                .map(|v| v as u32),
181            eos_token_id: config_json
182                .get("eos_token_id")
183                .and_then(|v| v.as_u64())
184                .map(|v| v as u32),
185            pad_token_id: config_json
186                .get("pad_token_id")
187                .and_then(|v| v.as_u64())
188                .map(|v| v as u32),
189        };
190
191        // Extract chat template from tokenizer or config
192        let chat_template = Self::load_chat_template(model_path, &config_json);
193
194        Ok(Self {
195            inference_runtime: tokio::sync::Mutex::new(inference_runtime),
196            tokenizer,
197            model_config,
198            max_tokens: max_tokens.unwrap_or(50),
199            temperature: temperature.unwrap_or(0.7),
200            top_p: top_p.unwrap_or(0.9),
201            system,
202            chat_template,
203        })
204    }
205
206    /// Load chat template from chat_template.jinja file, tokenizer.json or config.json
207    fn load_chat_template<P: AsRef<Path>>(model_path: P, config: &Value) -> Option<String> {
208        let model_path = model_path.as_ref();
209
210        // First try to load from chat_template.jinja file
211        let jinja_template_path = model_path.join("chat_template.jinja");
212        if jinja_template_path.exists() {
213            if let Ok(template_content) = std::fs::read_to_string(&jinja_template_path) {
214                log::debug!("Loaded chat template from chat_template.jinja");
215                return Some(template_content);
216            }
217        }
218
219        // Fallback to tokenizer.json
220        let tokenizer_path = model_path.join("tokenizer.json");
221        if tokenizer_path.exists() {
222            if let Ok(tokenizer_content) = std::fs::read_to_string(&tokenizer_path) {
223                if let Ok(tokenizer_json) = serde_json::from_str::<Value>(&tokenizer_content) {
224                    if let Some(chat_template) =
225                        tokenizer_json.get("chat_template").and_then(|v| v.as_str())
226                    {
227                        log::debug!("Loaded chat template from tokenizer.json");
228                        return Some(chat_template.to_string());
229                    }
230                }
231            }
232        }
233
234        // Fallback to config.json
235        if let Some(chat_template) = config.get("chat_template").and_then(|v| v.as_str()) {
236            log::debug!("Loaded chat template from config.json");
237            return Some(chat_template.to_string());
238        }
239
240        log::debug!("No chat template found");
241        None
242    }
243
244    // TODO: DO Better
245    /// Format messages into a prompt using Jinja2 chat template
246    fn format_messages(&self, messages: &[ChatMessage]) -> String {
247        // Prepare all messages including system message if provided
248        let mut all_messages = Vec::new();
249
250        // Add system message from constructor if provided and no system message in messages
251        if let Some(system) = &self.system {
252            let has_system = messages.iter().any(|m| matches!(m.role, ChatRole::System));
253            if !has_system {
254                all_messages.push(ChatMessage {
255                    role: ChatRole::System,
256                    message_type: MessageType::Text,
257                    content: system.clone(),
258                });
259            }
260        }
261
262        // Add provided messages
263        all_messages.extend_from_slice(messages);
264
265        // Use Jinja2 chat template if available
266        match self.apply_jinja_template(&all_messages) {
267            Ok(formatted) => {
268                log::debug!("Using Jinja2 chat template");
269                formatted
270            }
271            Err(e) => {
272                log::error!("Chat template required but not available or failed: {e}");
273                log::error!("Please provide a chat_template.jinja file in the model directory");
274                // Return a minimal error-indicating response
275                "Error: No chat template found. Please add chat_template.jinja file to model directory.".to_string()
276            }
277        }
278    }
279
280    /// Apply Jinja2 chat template
281    fn apply_jinja_template(&self, messages: &[ChatMessage]) -> Result<String, LLMError> {
282        let template_str = self
283            .chat_template
284            .as_ref()
285            .ok_or_else(|| LLMError::ProviderError("No chat template available".to_string()))?;
286
287        // Create Jinja2 environment
288        let mut env = Environment::new();
289
290        // Convert ChatMessage to template format
291        let template_messages: Vec<serde_json::Value> = messages
292            .iter()
293            .map(|msg| {
294                let role = match msg.role {
295                    ChatRole::System => "system",
296                    ChatRole::User => "user",
297                    ChatRole::Assistant => "assistant",
298                    ChatRole::Tool => "tool",
299                };
300
301                serde_json::json!({
302                    "role": role,
303                    "content": msg.content
304                })
305            })
306            .collect();
307
308        // Add template to environment
309        env.add_template("chat", template_str)
310            .map_err(|e| LLMError::ProviderError(format!("Failed to parse chat template: {e}")))?;
311
312        // Render template with messages
313        let template = env
314            .get_template("chat")
315            .map_err(|e| LLMError::ProviderError(format!("Failed to get chat template: {e}")))?;
316
317        let rendered = template
318            .render(context! {
319                messages => template_messages,
320                add_generation_prompt => true,
321                bos_token => "<s>",
322                eos_token => "</s>",
323                system_message => self.system.as_deref().unwrap_or(""),
324            })
325            .map_err(|e| LLMError::ProviderError(format!("Failed to render chat template: {e}")))?;
326
327        Ok(rendered)
328    }
329
    /// Generate text using the inference runtime.
    ///
    /// Autoregressive decode loop: every step re-runs inference over the
    /// *entire* token sequence so far (there is no KV cache, so step cost
    /// grows with sequence length), slices out the logits of the last
    /// position, picks the next token (top-p sampling or greedy argmax),
    /// and stops on EOS or after `config.max_new_tokens` tokens. Only the
    /// newly generated tokens are decoded and returned; the prompt is
    /// stripped from the output.
    async fn generate_text(
        &self,
        prompt: &str,
        config: GenerationConfig,
    ) -> Result<String, LLMError> {
        // Tokenize input (the `true` flag adds special tokens).
        let encoding = self
            .tokenizer
            .encode(prompt, true)
            .map_err(|e| LLMError::ProviderError(format!("Tokenization failed: {e}")))?;
        let input_tokens: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();

        log::info!(
            "Starting LLM generation with {} input tokens, max_new_tokens: {}",
            input_tokens.len(),
            config.max_new_tokens
        );

        let mut output_tokens = input_tokens.clone();
        let max_length = input_tokens.len() + config.max_new_tokens as usize;

        // Hold the runtime lock for the whole generation; concurrent chat
        // requests serialize here.
        let mut runtime = self.inference_runtime.lock().await;

        for step in 0..config.max_new_tokens {
            // NOTE(review): this check is redundant with the loop bound
            // (exactly one token is appended per iteration), but harmless.
            if output_tokens.len() >= max_length {
                log::info!("Reached max length, stopping generation");
                break;
            }

            log::debug!("Generation step {}/{}", step + 1, config.max_new_tokens);

            // Prepare inputs for inference: the full sequence every step.
            let seq_len = output_tokens.len();
            let mut inference_input = InferenceInput::new();

            // Add input_ids as a JSON array of token ids.
            let input_ids_json = Value::Array(
                output_tokens
                    .iter()
                    .map(|&x| Value::Number(x.into()))
                    .collect(),
            );
            inference_input = inference_input.add_input("input_ids".to_string(), input_ids_json);

            // Add attention_mask — all ones, since a single unpadded sequence.
            let attention_mask: Vec<Value> = vec![Value::Number(1.into()); seq_len];
            inference_input = inference_input
                .add_input("attention_mask".to_string(), Value::Array(attention_mask));

            // Add position_ids 0..seq_len.
            let position_ids: Vec<Value> = (0..seq_len as i64)
                .map(|x| Value::Number(x.into()))
                .collect();
            inference_input =
                inference_input.add_input("position_ids".to_string(), Value::Array(position_ids));

            // Run inference
            log::debug!("Running inference...");
            let output = runtime
                .infer(inference_input)
                .map_err(|e| LLMError::ProviderError(format!("Inference failed: {e}")))?;
            log::debug!("Inference completed");

            // Get logits from output
            let logits = output
                .get_output("logits")
                .ok_or_else(|| LLMError::ProviderError("No logits output found".to_string()))?;

            // Extract logits for the last token
            let logits_array = logits.as_array().ok_or_else(|| {
                LLMError::ProviderError("Logits output is not an array".to_string())
            })?;

            // Get logits for the last token. Assumes the output is flattened
            // row-major as [seq_len * vocab_size] for batch size 1 —
            // TODO confirm against the runtime's actual output layout.
            let vocab_size = self.model_config.vocab_size as usize;
            let last_token_start = (seq_len - 1) * vocab_size;
            let last_token_end = last_token_start + vocab_size;

            if logits_array.len() < last_token_end {
                return Err(LLMError::ProviderError("Invalid logits shape".to_string()));
            }

            // Non-numeric logit entries silently become 0.0 here.
            let last_token_logits: Vec<f32> = logits_array[last_token_start..last_token_end]
                .iter()
                .map(|v| v.as_f64().unwrap_or(0.0) as f32)
                .collect();

            log::debug!(
                "Got logits for last token, size: {}",
                last_token_logits.len()
            );

            // Sample next token (top-p) or take the argmax when sampling is off.
            let next_token = if config.do_sample {
                self.sample_token(&last_token_logits, config.temperature, config.top_p)?
            } else {
                // Greedy decoding
                last_token_logits
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                    .map(|(i, _)| i as i64)
                    .ok_or_else(|| {
                        LLMError::ProviderError("Failed to find max logit".to_string())
                    })?
            };

            log::debug!("Generated token: {next_token}");

            // Check for EOS token — stop before appending it to the output.
            if let Some(eos_id) = self.model_config.eos_token_id {
                if next_token == eos_id as i64 {
                    log::info!("Generated EOS token, stopping generation");
                    break;
                }
            }

            output_tokens.push(next_token);
        }

        log::info!(
            "Generation completed. Total tokens: {}, generated: {}",
            output_tokens.len(),
            output_tokens.len() - input_tokens.len()
        );

        // Decode generated tokens (skip input tokens)
        let generated_tokens: Vec<u32> = output_tokens[input_tokens.len()..]
            .iter()
            .map(|&x| x as u32)
            .collect();
        let generated_text = self
            .tokenizer
            .decode(&generated_tokens, true)
            .map_err(|e| LLMError::ProviderError(format!("Failed to decode tokens: {e}")))?;

        log::info!("Generated text: {generated_text}");
        Ok(generated_text)
    }
470
471    /// Sample next token using top-p sampling
472    fn sample_token(&self, logits: &[f32], temperature: f32, top_p: f32) -> Result<i64, LLMError> {
473        use rand::Rng;
474
475        // Apply temperature
476        let scaled_logits: Vec<f32> = logits.iter().map(|x| x / temperature).collect();
477
478        // Convert to probabilities (softmax)
479        let max_logit = scaled_logits
480            .iter()
481            .fold(f32::NEG_INFINITY, |a, &b| a.max(b));
482        let exp_logits: Vec<f32> = scaled_logits
483            .iter()
484            .map(|x| (x - max_logit).exp())
485            .collect();
486        let sum_exp: f32 = exp_logits.iter().sum();
487        let probs: Vec<f32> = exp_logits.iter().map(|x| x / sum_exp).collect();
488
489        // Top-p sampling
490        let mut sorted_indices: Vec<usize> = (0..probs.len()).collect();
491        sorted_indices.sort_by(|&a, &b| probs[b].partial_cmp(&probs[a]).unwrap());
492
493        let mut cumulative_prob = 0.0;
494        let mut cutoff_index = probs.len();
495
496        for (i, &idx) in sorted_indices.iter().enumerate() {
497            cumulative_prob += probs[idx];
498            if cumulative_prob >= top_p {
499                cutoff_index = i + 1;
500                break;
501            }
502        }
503
504        // Sample from the top-p distribution
505        let mut rng = rand::rng();
506        let random_value: f32 = rng.random();
507
508        let mut cumulative = 0.0;
509        for &idx in sorted_indices.iter().take(cutoff_index) {
510            cumulative += probs[idx];
511            if random_value <= cumulative {
512                return Ok(idx as i64);
513            }
514        }
515
516        // Fallback to the most probable token
517        Ok(sorted_indices[0] as i64)
518    }
519}
520
521#[async_trait]
522impl ChatProvider for OnnxEdge {
523    async fn chat(
524        &self,
525        messages: &[ChatMessage],
526        _tools: Option<&[Tool]>,
527        json_schema: Option<StructuredOutputFormat>,
528    ) -> Result<Box<dyn ChatResponse>, LLMError> {
529        let mut modified_messages = messages.to_vec();
530
531        // Add JSON format instruction if schema is provided
532        if let Some(schema) = &json_schema {
533            let default_schema = serde_json::json!({});
534            let schema_json = schema.schema.as_ref().unwrap_or(&default_schema);
535            let schema_str =
536                serde_json::to_string_pretty(schema_json).unwrap_or_else(|_| "{}".to_string());
537
538            //TODO: Improve
539            let json_instruction = format!(
540                "You must respond with valid JSON that matches this schema: {schema_str}. Only return the JSON, no additional text.").to_string();
541
542            modified_messages.insert(
543                0,
544                ChatMessage {
545                    role: ChatRole::System,
546                    message_type: MessageType::Text,
547                    content: json_instruction,
548                },
549            );
550        }
551
552        let prompt = self.format_messages(&modified_messages);
553
554        let generation_config = GenerationConfig {
555            max_new_tokens: self.max_tokens,
556            temperature: self.temperature,
557            top_p: self.top_p,
558            do_sample: true,
559        };
560
561        let response_text = self.generate_text(&prompt, generation_config).await?;
562        let cleaned_response = response_text.trim().to_string();
563
564        Ok(Box::new(EdgeResponse {
565            text: if cleaned_response.is_empty() {
566                "I'm here to help! What would you like to know?".to_string()
567            } else {
568                cleaned_response
569            },
570        }))
571    }
572}
573
574#[async_trait]
575impl CompletionProvider for OnnxEdge {
576    async fn complete(
577        &self,
578        req: &CompletionRequest,
579        _json_schema: Option<StructuredOutputFormat>,
580    ) -> Result<CompletionResponse, LLMError> {
581        let generation_config = GenerationConfig {
582            max_new_tokens: self.max_tokens,
583            temperature: self.temperature,
584            top_p: self.top_p,
585            do_sample: true,
586        };
587
588        let text = self.generate_text(&req.prompt, generation_config).await?;
589        Ok(CompletionResponse { text })
590    }
591}
592
593#[async_trait]
594impl EmbeddingProvider for OnnxEdge {
595    async fn embed(&self, _input: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
596        Err(LLMError::ProviderError(
597            "Embedding not supported by LiquidEdge backend".to_string(),
598        ))
599    }
600}
601
602#[async_trait]
603impl ModelsProvider for OnnxEdge {
604    async fn list_models(
605        &self,
606        _request: Option<&ModelListRequest>,
607    ) -> Result<Box<dyn ModelListResponse>, LLMError> {
608        Err(LLMError::ProviderError(
609            "Model listing not supported by LiquidEdge backend".to_string(),
610        ))
611    }
612}
613
// Marker impl: OnnxEdge satisfies the aggregate LLMProvider trait through
// the chat/completion/embedding/models implementations above.
impl LLMProvider for OnnxEdge {}
615
/// Builder for `OnnxEdge` instances.
#[derive(Debug, Default)]
pub struct LiquidEdgeBuilder {
    // Required: the ONNX model to load (`build` fails when unset).
    model: Option<Box<dyn Model>>,
    // Optional execution device; CPU is used when unset.
    device: Option<Device>,
    // Optional generation parameters; `None` falls back to backend defaults.
    max_tokens: Option<u32>,
    temperature: Option<f32>,
    top_p: Option<f32>,
    // Optional system prompt prepended to conversations.
    system: Option<String>,
}
625
626impl LiquidEdgeBuilder {
627    pub fn new() -> Self {
628        Self::default()
629    }
630
631    pub fn model(mut self, model: Box<dyn Model>) -> Self {
632        self.model = Some(model);
633        self
634    }
635
636    pub fn device(mut self, device: Device) -> Self {
637        self.device = Some(device);
638        self
639    }
640
641    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
642        self.max_tokens = Some(max_tokens);
643        self
644    }
645
646    pub fn temperature(mut self, temperature: f32) -> Self {
647        self.temperature = Some(temperature);
648        self
649    }
650
651    pub fn top_p(mut self, top_p: f32) -> Self {
652        self.top_p = Some(top_p);
653        self
654    }
655
656    pub fn system(mut self, system: &str) -> Self {
657        self.system = Some(system.to_string());
658        self
659    }
660
661    pub async fn build(self) -> Result<Arc<OnnxEdge>, LLMError> {
662        let liquid_edge = if let Some(model) = self.model {
663            if let Some(device) = self.device {
664                OnnxEdge::from_model_with_device(
665                    model,
666                    device,
667                    "onnx-ort-model".to_string(),
668                    self.max_tokens,
669                    self.temperature,
670                    self.top_p,
671                    self.system,
672                )
673                .await?
674            } else {
675                OnnxEdge::from_model(
676                    model,
677                    "onnx-ort-model".to_string(),
678                    self.max_tokens,
679                    self.temperature,
680                    self.top_p,
681                    self.system,
682                )
683                .await?
684            }
685        } else {
686            return Err(LLMError::InvalidRequest(
687                "edge_model must be provided for LiquidEdge".to_string(),
688            ));
689        };
690
691        Ok(Arc::new(liquid_edge))
692    }
693}
694
695// Convert EdgeError to LLMError
696impl From<EdgeError> for LLMError {
697    fn from(err: EdgeError) -> Self {
698        LLMError::ProviderError(format!("LiquidEdge error: {err}"))
699    }
700}