// offline_intelligence/model_runtime/runtime_trait.rs
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;

/// Model serialization formats this runtime layer recognizes.
///
/// Selected in `RuntimeConfig::format` and reported by
/// `ModelRuntime::supported_format` / `RuntimeMetadata::format`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModelFormat {
    /// llama.cpp single-file format (`.gguf`).
    GGUF,
    /// Legacy llama.cpp format (`.ggml` / `.bin`).
    GGML,
    /// ONNX graph, served via ONNX Runtime (`.onnx`).
    ONNX,
    /// NVIDIA TensorRT engine (`.trt` / `.engine` / `.plan`).
    TensorRT,
    /// Hugging Face safetensors weights (`.safetensors`).
    Safetensors,
    /// Apple CoreML model (`.mlmodel` / `.mlpackage`).
    CoreML,
}
23
24impl ModelFormat {
25 pub fn extensions(&self) -> &[&str] {
27 match self {
28 ModelFormat::GGUF => &["gguf"],
29 ModelFormat::GGML => &["ggml", "bin"],
30 ModelFormat::ONNX => &["onnx"],
31 ModelFormat::TensorRT => &["trt", "engine", "plan"],
32 ModelFormat::Safetensors => &["safetensors"],
33 ModelFormat::CoreML => &["mlmodel", "mlpackage"],
34 }
35 }
36
37 pub fn name(&self) -> &str {
39 match self {
40 ModelFormat::GGUF => "GGUF (llama.cpp)",
41 ModelFormat::GGML => "GGML (llama.cpp legacy)",
42 ModelFormat::ONNX => "ONNX Runtime",
43 ModelFormat::TensorRT => "TensorRT",
44 ModelFormat::Safetensors => "Safetensors",
45 ModelFormat::CoreML => "CoreML",
46 }
47 }
48}
49
/// Configuration for launching and addressing one local model runtime.
///
/// See the `Default` impl for the baseline values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuntimeConfig {
    /// Path to the model weights on disk (empty in the default config).
    pub model_path: PathBuf,
    /// Serialization format of the model at `model_path`.
    pub format: ModelFormat,
    /// Host the runtime's endpoint binds to (default "127.0.0.1").
    pub host: String,
    /// TCP port for the runtime's endpoint (default 8001).
    pub port: u16,
    /// Context size passed to the runtime (default 8192).
    pub context_size: u32,
    /// Batch size passed to the runtime (default 128).
    pub batch_size: u32,
    /// Worker thread count (default 6).
    pub threads: u32,
    /// Layers offloaded to the GPU; 0 (the default) presumably means
    /// CPU-only — confirm against the runtime implementations.
    pub gpu_layers: u32,
    /// Explicit path to the runtime binary; `None` (default) leaves the
    /// choice to the runtime implementation.
    pub runtime_binary: Option<PathBuf>,
    /// Free-form, runtime-specific extra settings (default `{}`).
    pub extra_config: serde_json::Value,
}
74
75impl Default for RuntimeConfig {
76 fn default() -> Self {
77 Self {
78 model_path: PathBuf::new(),
79 format: ModelFormat::GGUF,
80 host: "127.0.0.1".to_string(),
81 port: 8001,
82 context_size: 8192,
83 batch_size: 128,
84 threads: 6,
85 gpu_layers: 0,
86 runtime_binary: None,
87 extra_config: serde_json::json!({}),
88 }
89 }
90}
91
/// A chat-style generation request handed to `ModelRuntime::generate`
/// or `ModelRuntime::generate_stream`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceRequest {
    /// Conversation messages included in the prompt.
    pub messages: Vec<ChatMessage>,
    /// Maximum tokens to generate; defaults to 2000 when absent from
    /// the deserialized payload.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
    /// Sampling temperature; defaults to 0.7 when absent.
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// Whether the caller wants a streamed response; defaults to false.
    #[serde(default = "default_stream")]
    pub stream: bool,
}
103
/// Serde default for `InferenceRequest::max_tokens`.
fn default_max_tokens() -> u32 {
    2000
}

/// Serde default for `InferenceRequest::temperature`.
fn default_temperature() -> f32 {
    0.7
}

/// Serde default for `InferenceRequest::stream`.
fn default_stream() -> bool {
    false
}
107
/// One message in a chat conversation.
///
/// NOTE(review): `role` is stringly-typed; if the accepted values are a
/// fixed set (e.g. "system"/"user"/"assistant"), an enum would make
/// invalid roles unrepresentable — confirm the wire format first.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Speaker role; valid values depend on the backing runtime's API.
    pub role: String,
    /// Message text.
    pub content: String,
}
113
/// Result of a non-streaming generation (`ModelRuntime::generate`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceResponse {
    /// The generated text.
    pub content: String,
    /// Why generation stopped, if the runtime reported a reason.
    pub finish_reason: Option<String>,
}
120
/// Common interface over local model runtimes.
///
/// An implementation owns the lifecycle of one model endpoint:
/// `initialize` → serve `generate`/`generate_stream` calls → `shutdown`.
/// The trait is `Send + Sync` so runtimes can be shared across tasks.
#[async_trait]
pub trait ModelRuntime: Send + Sync {
    /// The single `ModelFormat` this runtime can load.
    fn supported_format(&self) -> ModelFormat;

    /// Start the runtime with `config`; must succeed before any
    /// generation calls are made.
    async fn initialize(&mut self, config: RuntimeConfig) -> anyhow::Result<()>;

    /// Whether the runtime is currently able to serve requests.
    async fn is_ready(&self) -> bool;

    /// Probe the runtime; `Ok` carries a status string.
    async fn health_check(&self) -> anyhow::Result<String>;

    /// Base URL of the runtime's endpoint (no trailing slash —
    /// `completions_url` appends a `/`-prefixed path to it).
    fn base_url(&self) -> String;

    /// OpenAI-compatible chat-completions endpoint derived from
    /// `base_url`. Override only if the runtime uses a different path.
    fn completions_url(&self) -> String {
        format!("{}/v1/chat/completions", self.base_url())
    }

    /// Run one non-streaming generation.
    async fn generate(
        &self,
        request: InferenceRequest,
    ) -> anyhow::Result<InferenceResponse>;

    /// Run a streaming generation; the stream yields text chunks (or
    /// errors) as they are produced.
    async fn generate_stream(
        &self,
        request: InferenceRequest,
    ) -> anyhow::Result<Box<dyn futures_util::Stream<Item = Result<String, anyhow::Error>> + Send + Unpin>>;

    /// Stop the runtime and release its resources.
    async fn shutdown(&mut self) -> anyhow::Result<()>;

    /// Static description of this runtime's identity and capabilities.
    fn metadata(&self) -> RuntimeMetadata;
}
162
/// Descriptive metadata about a `ModelRuntime` implementation,
/// returned by `ModelRuntime::metadata`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuntimeMetadata {
    /// Model format this runtime serves.
    pub format: ModelFormat,
    /// Human-readable runtime name.
    pub runtime_name: String,
    /// Runtime version string.
    pub version: String,
    /// Whether the runtime can offload work to a GPU.
    pub supports_gpu: bool,
    /// Whether the runtime implements `generate_stream`.
    pub supports_streaming: bool,
}