Skip to main content

offline_intelligence/model_runtime/
runtime_trait.rs

1//! Core trait and types for model runtime abstraction
2
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use std::path::PathBuf;
6
7/// Supported model formats
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
9pub enum ModelFormat {
10    /// GGUF format (llama.cpp quantized)
11    GGUF,
12    /// GGML format (llama.cpp legacy)
13    GGML,
14    /// ONNX format (Open Neural Network Exchange)
15    ONNX,
16    /// TensorRT optimized format (NVIDIA)
17    TensorRT,
18    /// Safetensors format (Hugging Face)
19    Safetensors,
20    /// CoreML format (Apple)
21    CoreML,
22}
23
24impl ModelFormat {
25    /// Get file extensions for this format
26    pub fn extensions(&self) -> &[&str] {
27        match self {
28            ModelFormat::GGUF => &["gguf"],
29            ModelFormat::GGML => &["ggml", "bin"],
30            ModelFormat::ONNX => &["onnx"],
31            ModelFormat::TensorRT => &["trt", "engine", "plan"],
32            ModelFormat::Safetensors => &["safetensors"],
33            ModelFormat::CoreML => &["mlmodel", "mlpackage"],
34        }
35    }
36
37    /// Get human-readable name
38    pub fn name(&self) -> &str {
39        match self {
40            ModelFormat::GGUF => "GGUF (llama.cpp)",
41            ModelFormat::GGML => "GGML (llama.cpp legacy)",
42            ModelFormat::ONNX => "ONNX Runtime",
43            ModelFormat::TensorRT => "TensorRT",
44            ModelFormat::Safetensors => "Safetensors",
45            ModelFormat::CoreML => "CoreML",
46        }
47    }
48}
49
50/// Runtime configuration for model initialization
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct RuntimeConfig {
53    /// Path to model file
54    pub model_path: PathBuf,
55    /// Model format
56    pub format: ModelFormat,
57    /// Host for runtime server (e.g., "127.0.0.1")
58    pub host: String,
59    /// Port for runtime server (e.g., 8001)
60    pub port: u16,
61    /// Context size
62    pub context_size: u32,
63    /// Batch size
64    pub batch_size: u32,
65    /// Number of CPU threads
66    pub threads: u32,
67    /// GPU layers to offload (0 = CPU only)
68    pub gpu_layers: u32,
69    /// Path to runtime binary (e.g., llama-server.exe)
70    pub runtime_binary: Option<PathBuf>,
71    /// Additional runtime-specific configuration
72    pub extra_config: serde_json::Value,
73}
74
75impl Default for RuntimeConfig {
76    fn default() -> Self {
77        Self {
78            model_path: PathBuf::new(),
79            format: ModelFormat::GGUF,
80            host: "127.0.0.1".to_string(),
81            port: 8001,
82            context_size: 8192,
83            batch_size: 128,
84            threads: 6,
85            gpu_layers: 0,
86            runtime_binary: None,
87            extra_config: serde_json::json!({}),
88        }
89    }
90}
91
92/// Inference request (OpenAI-compatible format)
93#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct InferenceRequest {
95    pub messages: Vec<ChatMessage>,
96    #[serde(default = "default_max_tokens")]
97    pub max_tokens: u32,
98    #[serde(default = "default_temperature")]
99    pub temperature: f32,
100    #[serde(default = "default_stream")]
101    pub stream: bool,
102}
103
/// Serde default for `InferenceRequest::max_tokens`.
fn default_max_tokens() -> u32 {
    2_000
}

/// Serde default for `InferenceRequest::temperature`.
fn default_temperature() -> f32 {
    0.7
}

/// Serde default for `InferenceRequest::stream`: non-streaming.
fn default_stream() -> bool {
    false
}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct ChatMessage {
110    pub role: String,
111    pub content: String,
112}
113
114/// Inference response
115#[derive(Debug, Clone, Serialize, Deserialize)]
116pub struct InferenceResponse {
117    pub content: String,
118    pub finish_reason: Option<String>,
119}
120
121/// Model runtime trait - all runtime adapters must implement this
122#[async_trait]
123pub trait ModelRuntime: Send + Sync {
124    /// Get the format this runtime supports
125    fn supported_format(&self) -> ModelFormat;
126
127    /// Initialize the runtime (start server process, load model, etc.)
128    async fn initialize(&mut self, config: RuntimeConfig) -> anyhow::Result<()>;
129
130    /// Check if runtime is ready for inference
131    async fn is_ready(&self) -> bool;
132
133    /// Get health status
134    async fn health_check(&self) -> anyhow::Result<String>;
135
136    /// Get the base URL for inference API (e.g., "http://127.0.0.1:8001")
137    fn base_url(&self) -> String;
138
139    /// Get the OpenAI-compatible chat completions endpoint
140    fn completions_url(&self) -> String {
141        format!("{}/v1/chat/completions", self.base_url())
142    }
143
144    /// Perform inference (non-streaming)
145    async fn generate(
146        &self,
147        request: InferenceRequest,
148    ) -> anyhow::Result<InferenceResponse>;
149
150    /// Perform streaming inference
151    async fn generate_stream(
152        &self,
153        request: InferenceRequest,
154    ) -> anyhow::Result<Box<dyn futures_util::Stream<Item = Result<String, anyhow::Error>> + Send + Unpin>>;
155
156    /// Shutdown the runtime (stop server, cleanup resources)
157    async fn shutdown(&mut self) -> anyhow::Result<()>;
158
159    /// Get runtime metadata
160    fn metadata(&self) -> RuntimeMetadata;
161}
162
163/// Runtime metadata
164#[derive(Debug, Clone, Serialize, Deserialize)]
165pub struct RuntimeMetadata {
166    pub format: ModelFormat,
167    pub runtime_name: String,
168    pub version: String,
169    pub supports_gpu: bool,
170    pub supports_streaming: bool,
171}