1use crate::{Device, EdgeError, InferenceInput, InferenceRuntime, Model};
7use async_trait::async_trait;
8use autoagents_llm::chat::Tool;
9use autoagents_llm::models::{ModelListRequest, ModelListResponse};
10use autoagents_llm::{
11 chat::{
12 ChatMessage, ChatProvider, ChatResponse, ChatRole, MessageType, StructuredOutputFormat,
13 },
14 completion::{CompletionProvider, CompletionRequest, CompletionResponse},
15 embedding::EmbeddingProvider,
16 error::LLMError,
17 models::ModelsProvider,
18 LLMProvider, ToolCall,
19};
20use minijinja::{context, Environment};
21use serde::{Deserialize, Serialize};
22use serde_json::Value;
23use std::{path::Path, sync::Arc};
24use tokenizers::Tokenizer;
25
/// ONNX Runtime-backed edge LLM provider.
///
/// Wraps an `InferenceRuntime` behind an async mutex (so `&self` chat calls
/// can serialize access to the model), a HuggingFace tokenizer, and the
/// default sampling parameters applied to every request.
pub struct OnnxEdge {
    // Async mutex: generation holds the runtime lock across `.await` points.
    inference_runtime: tokio::sync::Mutex<InferenceRuntime>,
    // Tokenizer loaded from `tokenizer.json` in the model directory.
    tokenizer: Tokenizer,
    // Subset of the model's `config.json` needed for the generation loop.
    model_config: ModelConfig,
    // Maximum number of new tokens generated per request.
    max_tokens: u32,
    // Softmax temperature used when sampling.
    temperature: f32,
    // Nucleus (top-p) sampling threshold.
    top_p: f32,
    // Optional system prompt injected when the messages contain none.
    system: Option<String>,
    // Jinja chat template, if one was found alongside the model files.
    chat_template: Option<String>,
}
37
/// The subset of a HuggingFace-style `config.json` that generation depends on.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Vocabulary size (width of one logits row).
    pub vocab_size: u32,
    /// Maximum sequence length the model supports.
    pub max_position_embeddings: u32,
    /// Beginning-of-sequence token id, when declared in the config.
    pub bos_token_id: Option<u32>,
    /// End-of-sequence token id; generation stops when it is produced.
    pub eos_token_id: Option<u32>,
    /// Padding token id, when declared in the config.
    pub pad_token_id: Option<u32>,
}
47
/// Per-request decoding parameters for the generation loop.
#[derive(Debug, Clone)]
pub struct GenerationConfig {
    /// Upper bound on tokens generated beyond the prompt.
    pub max_new_tokens: u32,
    /// Softmax temperature applied to logits before sampling.
    pub temperature: f32,
    /// Nucleus (top-p) cumulative-probability cutoff.
    pub top_p: f32,
    /// `true` to sample; `false` for greedy argmax decoding.
    pub do_sample: bool,
}
56
57impl Default for GenerationConfig {
58 fn default() -> Self {
59 Self {
60 max_new_tokens: 50,
61 temperature: 0.7,
62 top_p: 0.9,
63 do_sample: true,
64 }
65 }
66}
67
/// Chat response wrapper holding the raw generated text.
#[derive(Debug)]
pub struct EdgeResponse {
    // Decoded model output (trimmed by `chat` before construction).
    text: String,
}
73
74impl ChatResponse for EdgeResponse {
75 fn text(&self) -> Option<String> {
76 Some(self.text.clone())
77 }
78
79 fn tool_calls(&self) -> Option<Vec<ToolCall>> {
80 None }
82}
83
84impl std::fmt::Display for EdgeResponse {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 write!(f, "{}", self.text)
87 }
88}
89
90impl OnnxEdge {
91 pub async fn from_model_with_device(
93 model: Box<dyn Model>,
94 device: Device,
95 _model_name: String,
96 max_tokens: Option<u32>,
97 temperature: Option<f32>,
98 top_p: Option<f32>,
99 system: Option<String>,
100 ) -> Result<Self, LLMError> {
101 let model_path = model.model_path().to_path_buf();
102
103 let inference_runtime = InferenceRuntime::from_model_with_device(model, device)
105 .await
106 .map_err(|e| LLMError::ProviderError(format!("Failed to load model: {e}")))?;
107
108 Self::from_runtime(
109 inference_runtime,
110 model_path,
111 max_tokens,
112 temperature,
113 top_p,
114 system,
115 )
116 .await
117 }
118
119 pub async fn from_model(
121 model: Box<dyn Model>,
122 _model_name: String,
123 max_tokens: Option<u32>,
124 temperature: Option<f32>,
125 top_p: Option<f32>,
126 system: Option<String>,
127 ) -> Result<Self, LLMError> {
128 let model_path = model.model_path().to_path_buf();
129
130 let inference_runtime = InferenceRuntime::from_model(model)
132 .await
133 .map_err(|e| LLMError::ProviderError(format!("Failed to load model: {e}")))?;
134
135 Self::from_runtime(
136 inference_runtime,
137 model_path,
138 max_tokens,
139 temperature,
140 top_p,
141 system,
142 )
143 .await
144 }
145
146 async fn from_runtime(
148 inference_runtime: InferenceRuntime,
149 model_path: std::path::PathBuf,
150 max_tokens: Option<u32>,
151 temperature: Option<f32>,
152 top_p: Option<f32>,
153 system: Option<String>,
154 ) -> Result<Self, LLMError> {
155 let tokenizer_path = model_path.join("tokenizer.json");
157 let tokenizer = Tokenizer::from_file(&tokenizer_path)
158 .map_err(|e| LLMError::ProviderError(format!("Failed to load tokenizer: {e}")))?;
159
160 let config_path = model_path.join("config.json");
162 let config_content = std::fs::read_to_string(&config_path)
163 .map_err(|e| LLMError::ProviderError(format!("Failed to read config.json: {e}")))?;
164
165 let config_json: Value = serde_json::from_str(&config_content)
166 .map_err(|e| LLMError::ProviderError(format!("Failed to parse config.json: {e}")))?;
167
168 let model_config = ModelConfig {
169 vocab_size: config_json
170 .get("vocab_size")
171 .and_then(|v| v.as_u64())
172 .unwrap_or(32000) as u32,
173 max_position_embeddings: config_json
174 .get("max_position_embeddings")
175 .and_then(|v| v.as_u64())
176 .unwrap_or(2048) as u32,
177 bos_token_id: config_json
178 .get("bos_token_id")
179 .and_then(|v| v.as_u64())
180 .map(|v| v as u32),
181 eos_token_id: config_json
182 .get("eos_token_id")
183 .and_then(|v| v.as_u64())
184 .map(|v| v as u32),
185 pad_token_id: config_json
186 .get("pad_token_id")
187 .and_then(|v| v.as_u64())
188 .map(|v| v as u32),
189 };
190
191 let chat_template = Self::load_chat_template(model_path, &config_json);
193
194 Ok(Self {
195 inference_runtime: tokio::sync::Mutex::new(inference_runtime),
196 tokenizer,
197 model_config,
198 max_tokens: max_tokens.unwrap_or(50),
199 temperature: temperature.unwrap_or(0.7),
200 top_p: top_p.unwrap_or(0.9),
201 system,
202 chat_template,
203 })
204 }
205
206 fn load_chat_template<P: AsRef<Path>>(model_path: P, config: &Value) -> Option<String> {
208 let model_path = model_path.as_ref();
209
210 let jinja_template_path = model_path.join("chat_template.jinja");
212 if jinja_template_path.exists() {
213 if let Ok(template_content) = std::fs::read_to_string(&jinja_template_path) {
214 log::debug!("Loaded chat template from chat_template.jinja");
215 return Some(template_content);
216 }
217 }
218
219 let tokenizer_path = model_path.join("tokenizer.json");
221 if tokenizer_path.exists() {
222 if let Ok(tokenizer_content) = std::fs::read_to_string(&tokenizer_path) {
223 if let Ok(tokenizer_json) = serde_json::from_str::<Value>(&tokenizer_content) {
224 if let Some(chat_template) =
225 tokenizer_json.get("chat_template").and_then(|v| v.as_str())
226 {
227 log::debug!("Loaded chat template from tokenizer.json");
228 return Some(chat_template.to_string());
229 }
230 }
231 }
232 }
233
234 if let Some(chat_template) = config.get("chat_template").and_then(|v| v.as_str()) {
236 log::debug!("Loaded chat template from config.json");
237 return Some(chat_template.to_string());
238 }
239
240 log::debug!("No chat template found");
241 None
242 }
243
244 fn format_messages(&self, messages: &[ChatMessage]) -> String {
247 let mut all_messages = Vec::new();
249
250 if let Some(system) = &self.system {
252 let has_system = messages.iter().any(|m| matches!(m.role, ChatRole::System));
253 if !has_system {
254 all_messages.push(ChatMessage {
255 role: ChatRole::System,
256 message_type: MessageType::Text,
257 content: system.clone(),
258 });
259 }
260 }
261
262 all_messages.extend_from_slice(messages);
264
265 match self.apply_jinja_template(&all_messages) {
267 Ok(formatted) => {
268 log::debug!("Using Jinja2 chat template");
269 formatted
270 }
271 Err(e) => {
272 log::error!("Chat template required but not available or failed: {e}");
273 log::error!("Please provide a chat_template.jinja file in the model directory");
274 "Error: No chat template found. Please add chat_template.jinja file to model directory.".to_string()
276 }
277 }
278 }
279
280 fn apply_jinja_template(&self, messages: &[ChatMessage]) -> Result<String, LLMError> {
282 let template_str = self
283 .chat_template
284 .as_ref()
285 .ok_or_else(|| LLMError::ProviderError("No chat template available".to_string()))?;
286
287 let mut env = Environment::new();
289
290 let template_messages: Vec<serde_json::Value> = messages
292 .iter()
293 .map(|msg| {
294 let role = match msg.role {
295 ChatRole::System => "system",
296 ChatRole::User => "user",
297 ChatRole::Assistant => "assistant",
298 ChatRole::Tool => "tool",
299 };
300
301 serde_json::json!({
302 "role": role,
303 "content": msg.content
304 })
305 })
306 .collect();
307
308 env.add_template("chat", template_str)
310 .map_err(|e| LLMError::ProviderError(format!("Failed to parse chat template: {e}")))?;
311
312 let template = env
314 .get_template("chat")
315 .map_err(|e| LLMError::ProviderError(format!("Failed to get chat template: {e}")))?;
316
317 let rendered = template
318 .render(context! {
319 messages => template_messages,
320 add_generation_prompt => true,
321 bos_token => "<s>",
322 eos_token => "</s>",
323 system_message => self.system.as_deref().unwrap_or(""),
324 })
325 .map_err(|e| LLMError::ProviderError(format!("Failed to render chat template: {e}")))?;
326
327 Ok(rendered)
328 }
329
330 async fn generate_text(
332 &self,
333 prompt: &str,
334 config: GenerationConfig,
335 ) -> Result<String, LLMError> {
336 let encoding = self
338 .tokenizer
339 .encode(prompt, true)
340 .map_err(|e| LLMError::ProviderError(format!("Tokenization failed: {e}")))?;
341 let input_tokens: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
342
343 log::info!(
344 "Starting LLM generation with {} input tokens, max_new_tokens: {}",
345 input_tokens.len(),
346 config.max_new_tokens
347 );
348
349 let mut output_tokens = input_tokens.clone();
350 let max_length = input_tokens.len() + config.max_new_tokens as usize;
351
352 let mut runtime = self.inference_runtime.lock().await;
353
354 for step in 0..config.max_new_tokens {
355 if output_tokens.len() >= max_length {
356 log::info!("Reached max length, stopping generation");
357 break;
358 }
359
360 log::debug!("Generation step {}/{}", step + 1, config.max_new_tokens);
361
362 let seq_len = output_tokens.len();
364 let mut inference_input = InferenceInput::new();
365
366 let input_ids_json = Value::Array(
368 output_tokens
369 .iter()
370 .map(|&x| Value::Number(x.into()))
371 .collect(),
372 );
373 inference_input = inference_input.add_input("input_ids".to_string(), input_ids_json);
374
375 let attention_mask: Vec<Value> = vec![Value::Number(1.into()); seq_len];
377 inference_input = inference_input
378 .add_input("attention_mask".to_string(), Value::Array(attention_mask));
379
380 let position_ids: Vec<Value> = (0..seq_len as i64)
382 .map(|x| Value::Number(x.into()))
383 .collect();
384 inference_input =
385 inference_input.add_input("position_ids".to_string(), Value::Array(position_ids));
386
387 log::debug!("Running inference...");
389 let output = runtime
390 .infer(inference_input)
391 .map_err(|e| LLMError::ProviderError(format!("Inference failed: {e}")))?;
392 log::debug!("Inference completed");
393
394 let logits = output
396 .get_output("logits")
397 .ok_or_else(|| LLMError::ProviderError("No logits output found".to_string()))?;
398
399 let logits_array = logits.as_array().ok_or_else(|| {
401 LLMError::ProviderError("Logits output is not an array".to_string())
402 })?;
403
404 let vocab_size = self.model_config.vocab_size as usize;
406 let last_token_start = (seq_len - 1) * vocab_size;
407 let last_token_end = last_token_start + vocab_size;
408
409 if logits_array.len() < last_token_end {
410 return Err(LLMError::ProviderError("Invalid logits shape".to_string()));
411 }
412
413 let last_token_logits: Vec<f32> = logits_array[last_token_start..last_token_end]
414 .iter()
415 .map(|v| v.as_f64().unwrap_or(0.0) as f32)
416 .collect();
417
418 log::debug!(
419 "Got logits for last token, size: {}",
420 last_token_logits.len()
421 );
422
423 let next_token = if config.do_sample {
425 self.sample_token(&last_token_logits, config.temperature, config.top_p)?
426 } else {
427 last_token_logits
429 .iter()
430 .enumerate()
431 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
432 .map(|(i, _)| i as i64)
433 .ok_or_else(|| {
434 LLMError::ProviderError("Failed to find max logit".to_string())
435 })?
436 };
437
438 log::debug!("Generated token: {next_token}");
439
440 if let Some(eos_id) = self.model_config.eos_token_id {
442 if next_token == eos_id as i64 {
443 log::info!("Generated EOS token, stopping generation");
444 break;
445 }
446 }
447
448 output_tokens.push(next_token);
449 }
450
451 log::info!(
452 "Generation completed. Total tokens: {}, generated: {}",
453 output_tokens.len(),
454 output_tokens.len() - input_tokens.len()
455 );
456
457 let generated_tokens: Vec<u32> = output_tokens[input_tokens.len()..]
459 .iter()
460 .map(|&x| x as u32)
461 .collect();
462 let generated_text = self
463 .tokenizer
464 .decode(&generated_tokens, true)
465 .map_err(|e| LLMError::ProviderError(format!("Failed to decode tokens: {e}")))?;
466
467 log::info!("Generated text: {generated_text}");
468 Ok(generated_text)
469 }
470
471 fn sample_token(&self, logits: &[f32], temperature: f32, top_p: f32) -> Result<i64, LLMError> {
473 use rand::Rng;
474
475 let scaled_logits: Vec<f32> = logits.iter().map(|x| x / temperature).collect();
477
478 let max_logit = scaled_logits
480 .iter()
481 .fold(f32::NEG_INFINITY, |a, &b| a.max(b));
482 let exp_logits: Vec<f32> = scaled_logits
483 .iter()
484 .map(|x| (x - max_logit).exp())
485 .collect();
486 let sum_exp: f32 = exp_logits.iter().sum();
487 let probs: Vec<f32> = exp_logits.iter().map(|x| x / sum_exp).collect();
488
489 let mut sorted_indices: Vec<usize> = (0..probs.len()).collect();
491 sorted_indices.sort_by(|&a, &b| probs[b].partial_cmp(&probs[a]).unwrap());
492
493 let mut cumulative_prob = 0.0;
494 let mut cutoff_index = probs.len();
495
496 for (i, &idx) in sorted_indices.iter().enumerate() {
497 cumulative_prob += probs[idx];
498 if cumulative_prob >= top_p {
499 cutoff_index = i + 1;
500 break;
501 }
502 }
503
504 let mut rng = rand::rng();
506 let random_value: f32 = rng.random();
507
508 let mut cumulative = 0.0;
509 for &idx in sorted_indices.iter().take(cutoff_index) {
510 cumulative += probs[idx];
511 if random_value <= cumulative {
512 return Ok(idx as i64);
513 }
514 }
515
516 Ok(sorted_indices[0] as i64)
518 }
519}
520
521#[async_trait]
522impl ChatProvider for OnnxEdge {
523 async fn chat(
524 &self,
525 messages: &[ChatMessage],
526 _tools: Option<&[Tool]>,
527 json_schema: Option<StructuredOutputFormat>,
528 ) -> Result<Box<dyn ChatResponse>, LLMError> {
529 let mut modified_messages = messages.to_vec();
530
531 if let Some(schema) = &json_schema {
533 let default_schema = serde_json::json!({});
534 let schema_json = schema.schema.as_ref().unwrap_or(&default_schema);
535 let schema_str =
536 serde_json::to_string_pretty(schema_json).unwrap_or_else(|_| "{}".to_string());
537
538 let json_instruction = format!(
540 "You must respond with valid JSON that matches this schema: {schema_str}. Only return the JSON, no additional text.").to_string();
541
542 modified_messages.insert(
543 0,
544 ChatMessage {
545 role: ChatRole::System,
546 message_type: MessageType::Text,
547 content: json_instruction,
548 },
549 );
550 }
551
552 let prompt = self.format_messages(&modified_messages);
553
554 let generation_config = GenerationConfig {
555 max_new_tokens: self.max_tokens,
556 temperature: self.temperature,
557 top_p: self.top_p,
558 do_sample: true,
559 };
560
561 let response_text = self.generate_text(&prompt, generation_config).await?;
562 let cleaned_response = response_text.trim().to_string();
563
564 Ok(Box::new(EdgeResponse {
565 text: if cleaned_response.is_empty() {
566 "I'm here to help! What would you like to know?".to_string()
567 } else {
568 cleaned_response
569 },
570 }))
571 }
572}
573
574#[async_trait]
575impl CompletionProvider for OnnxEdge {
576 async fn complete(
577 &self,
578 req: &CompletionRequest,
579 _json_schema: Option<StructuredOutputFormat>,
580 ) -> Result<CompletionResponse, LLMError> {
581 let generation_config = GenerationConfig {
582 max_new_tokens: self.max_tokens,
583 temperature: self.temperature,
584 top_p: self.top_p,
585 do_sample: true,
586 };
587
588 let text = self.generate_text(&req.prompt, generation_config).await?;
589 Ok(CompletionResponse { text })
590 }
591}
592
593#[async_trait]
594impl EmbeddingProvider for OnnxEdge {
595 async fn embed(&self, _input: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
596 Err(LLMError::ProviderError(
597 "Embedding not supported by LiquidEdge backend".to_string(),
598 ))
599 }
600}
601
602#[async_trait]
603impl ModelsProvider for OnnxEdge {
604 async fn list_models(
605 &self,
606 _request: Option<&ModelListRequest>,
607 ) -> Result<Box<dyn ModelListResponse>, LLMError> {
608 Err(LLMError::ProviderError(
609 "Model listing not supported by LiquidEdge backend".to_string(),
610 ))
611 }
612}
613
614impl LLMProvider for OnnxEdge {}
615
/// Builder for [`OnnxEdge`]; only the model is required.
#[derive(Debug, Default)]
pub struct LiquidEdgeBuilder {
    // Required: the edge model to load.
    model: Option<Box<dyn Model>>,
    // Optional inference device; runtime default is used when `None`.
    device: Option<Device>,
    // Optional cap on generated tokens (defaults to 50 downstream).
    max_tokens: Option<u32>,
    // Optional sampling temperature (defaults to 0.7 downstream).
    temperature: Option<f32>,
    // Optional top-p threshold (defaults to 0.9 downstream).
    top_p: Option<f32>,
    // Optional default system prompt.
    system: Option<String>,
}
625
626impl LiquidEdgeBuilder {
627 pub fn new() -> Self {
628 Self::default()
629 }
630
631 pub fn model(mut self, model: Box<dyn Model>) -> Self {
632 self.model = Some(model);
633 self
634 }
635
636 pub fn device(mut self, device: Device) -> Self {
637 self.device = Some(device);
638 self
639 }
640
641 pub fn max_tokens(mut self, max_tokens: u32) -> Self {
642 self.max_tokens = Some(max_tokens);
643 self
644 }
645
646 pub fn temperature(mut self, temperature: f32) -> Self {
647 self.temperature = Some(temperature);
648 self
649 }
650
651 pub fn top_p(mut self, top_p: f32) -> Self {
652 self.top_p = Some(top_p);
653 self
654 }
655
656 pub fn system(mut self, system: &str) -> Self {
657 self.system = Some(system.to_string());
658 self
659 }
660
661 pub async fn build(self) -> Result<Arc<OnnxEdge>, LLMError> {
662 let liquid_edge = if let Some(model) = self.model {
663 if let Some(device) = self.device {
664 OnnxEdge::from_model_with_device(
665 model,
666 device,
667 "onnx-ort-model".to_string(),
668 self.max_tokens,
669 self.temperature,
670 self.top_p,
671 self.system,
672 )
673 .await?
674 } else {
675 OnnxEdge::from_model(
676 model,
677 "onnx-ort-model".to_string(),
678 self.max_tokens,
679 self.temperature,
680 self.top_p,
681 self.system,
682 )
683 .await?
684 }
685 } else {
686 return Err(LLMError::InvalidRequest(
687 "edge_model must be provided for LiquidEdge".to_string(),
688 ));
689 };
690
691 Ok(Arc::new(liquid_edge))
692 }
693}
694
695impl From<EdgeError> for LLMError {
697 fn from(err: EdgeError) -> Self {
698 LLMError::ProviderError(format!("LiquidEdge error: {err}"))
699 }
700}