// offline_intelligence/model_runtime/runtime_trait.rs
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
/// On-disk model-weight formats the runtime layer can recognize. Each
/// variant maps to a set of file extensions and a display name (see the
/// `impl ModelFormat` block).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModelFormat {
    /// llama.cpp's current single-file format (`.gguf`).
    GGUF,
    /// llama.cpp's legacy format (`.ggml` / `.bin`).
    GGML,
    /// ONNX graph, served through ONNX Runtime.
    ONNX,
    /// NVIDIA TensorRT serialized engine/plan.
    TensorRT,
    /// Hugging Face safetensors weight file.
    Safetensors,
    /// Apple CoreML model or package.
    CoreML,
}
23
24impl ModelFormat {
25 pub fn extensions(&self) -> &[&str] {
27 match self {
28 ModelFormat::GGUF => &["gguf"],
29 ModelFormat::GGML => &["ggml", "bin"],
30 ModelFormat::ONNX => &["onnx"],
31 ModelFormat::TensorRT => &["trt", "engine", "plan"],
32 ModelFormat::Safetensors => &["safetensors"],
33 ModelFormat::CoreML => &["mlmodel", "mlpackage"],
34 }
35 }
36
37 pub fn name(&self) -> &str {
39 match self {
40 ModelFormat::GGUF => "GGUF (llama.cpp)",
41 ModelFormat::GGML => "GGML (llama.cpp legacy)",
42 ModelFormat::ONNX => "ONNX Runtime",
43 ModelFormat::TensorRT => "TensorRT",
44 ModelFormat::Safetensors => "Safetensors",
45 ModelFormat::CoreML => "CoreML",
46 }
47 }
48}
49
/// Everything needed to launch and address one local inference server.
/// Serializable so it can be persisted or shipped through config channels.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuntimeConfig {
    /// Path to the model weights file on disk.
    pub model_path: PathBuf,
    /// Format of `model_path`; selects the matching runtime backend.
    pub format: ModelFormat,
    /// Host/interface the runtime's HTTP server binds to.
    pub host: String,
    /// TCP port for the runtime's HTTP server.
    pub port: u16,
    /// Context window size in tokens.
    pub context_size: u32,
    /// Batch size. NOTE(review): field names here mirror llama.cpp server
    /// flags (`--batch-size`) — confirm against the launcher that consumes
    /// this config.
    pub batch_size: u32,
    /// CPU threads to use for inference.
    pub threads: u32,
    /// Number of layers to offload to the GPU (0 presumably = CPU only;
    /// verify against the backend's flag semantics).
    pub gpu_layers: u32,
    /// Number of concurrent request slots.
    pub parallel_slots: u32,
    /// Micro-batch size (llama.cpp `--ubatch-size` style — TODO confirm).
    pub ubatch_size: u32,
    /// Explicit path to the runtime executable; `None` means the launcher
    /// picks a default.
    pub runtime_binary: Option<PathBuf>,
    /// Optional draft model for speculative decoding; `None` disables it.
    pub draft_model_path: Option<PathBuf>,
    /// Maximum draft tokens per speculative decoding step.
    pub speculative_draft_max: u32,
    /// Minimum draft-token probability threshold for speculative decoding.
    pub speculative_draft_p_min: f32,
    /// Backend-specific settings that don't fit the common fields above.
    pub extra_config: serde_json::Value,
}
89
90impl Default for RuntimeConfig {
91 fn default() -> Self {
92 Self {
93 model_path: PathBuf::new(),
94 format: ModelFormat::GGUF,
95 host: "127.0.0.1".to_string(),
96 port: 8001,
97 context_size: 8192,
98 batch_size: 128,
99 threads: 6,
100 gpu_layers: 0,
101 parallel_slots: 1,
102 ubatch_size: 512,
103 runtime_binary: None,
104 draft_model_path: None,
105 speculative_draft_max: 8,
106 speculative_draft_p_min: 0.4,
107 extra_config: serde_json::json!({}),
108 }
109 }
110}
111
/// A chat-completion request in OpenAI-style message form. The serde
/// defaults let callers omit the tuning knobs entirely.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceRequest {
    /// Ordered conversation history (roles are free-form strings).
    pub messages: Vec<ChatMessage>,
    /// Maximum tokens to generate (defaults to 2000 when absent).
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
    /// Sampling temperature (defaults to 0.7 when absent).
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// Whether to stream the response (defaults to false when absent).
    #[serde(default = "default_stream")]
    pub stream: bool,
}
123
/// Serde default: generation cap applied when a request omits `max_tokens`.
fn default_max_tokens() -> u32 {
    2000
}

/// Serde default: sampling temperature applied when a request omits
/// `temperature`.
fn default_temperature() -> f32 {
    0.7
}

/// Serde default: requests are non-streaming unless they opt in.
fn default_stream() -> bool {
    false
}
127
/// One turn of a conversation, OpenAI chat style.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Speaker role, e.g. "system" / "user" / "assistant" — a free-form
    /// string, not validated here.
    pub role: String,
    /// Message text.
    pub content: String,
}
133
/// Result of a non-streaming generation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceResponse {
    /// The generated text.
    pub content: String,
    /// Why generation stopped (e.g. "stop" or "length") when the backend
    /// reports it; `None` otherwise.
    pub finish_reason: Option<String>,
}
140
/// Common async interface over local inference backends. `Send + Sync` so
/// implementations can be shared across tasks (e.g. behind an `Arc`).
#[async_trait]
pub trait ModelRuntime: Send + Sync {
    /// The single model format this backend can load.
    fn supported_format(&self) -> ModelFormat;

    /// Starts the backend with `config`. Must complete successfully before
    /// the generate methods are used.
    async fn initialize(&mut self, config: RuntimeConfig) -> anyhow::Result<()>;

    /// Cheap readiness probe: `true` once the backend can serve requests.
    async fn is_ready(&self) -> bool;

    /// Queries backend health; returns a status string on success, an error
    /// when the backend is unhealthy or unreachable.
    async fn health_check(&self) -> anyhow::Result<String>;

    /// Base HTTP URL of the running backend (scheme + host + port).
    fn base_url(&self) -> String;

    /// OpenAI-compatible chat-completions endpoint, derived from `base_url`.
    fn completions_url(&self) -> String {
        format!("{}/v1/chat/completions", self.base_url())
    }

    /// Runs one non-streaming chat completion.
    async fn generate(
        &self,
        request: InferenceRequest,
    ) -> anyhow::Result<InferenceResponse>;

    /// Runs a streaming chat completion; the returned stream yields text
    /// chunks (or errors) as they arrive.
    async fn generate_stream(
        &self,
        request: InferenceRequest,
    ) -> anyhow::Result<Box<dyn futures_util::Stream<Item = Result<String, anyhow::Error>> + Send + Unpin>>;

    /// Stops the backend and releases its resources.
    async fn shutdown(&mut self) -> anyhow::Result<()>;

    /// Static, capability-level facts about this runtime implementation.
    fn metadata(&self) -> RuntimeMetadata;
}
182
/// Descriptive information about a runtime backend, as reported by
/// `ModelRuntime::metadata`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuntimeMetadata {
    /// Model format the backend serves.
    pub format: ModelFormat,
    /// Backend name (e.g. "llama.cpp").
    pub runtime_name: String,
    /// Backend/runtime version string.
    pub version: String,
    /// Whether GPU offload is supported.
    pub supports_gpu: bool,
    /// Whether streaming generation is supported.
    pub supports_streaming: bool,
}