// offline_intelligence/model_runtime/runtime_trait.rs
//! Core trait and types for model runtime abstraction.
2
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use std::path::PathBuf;
6
7/// Supported model formats
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
9pub enum ModelFormat {
10    /// GGUF format (llama.cpp quantized)
11    GGUF,
12    /// GGML format (llama.cpp legacy)
13    GGML,
14    /// ONNX format (Open Neural Network Exchange)
15    ONNX,
16    /// TensorRT optimized format (NVIDIA)
17    TensorRT,
18    /// Safetensors format (Hugging Face)
19    Safetensors,
20    /// CoreML format (Apple)
21    CoreML,
22}
23
24impl ModelFormat {
25    /// Get file extensions for this format
26    pub fn extensions(&self) -> &[&str] {
27        match self {
28            ModelFormat::GGUF => &["gguf"],
29            ModelFormat::GGML => &["ggml", "bin"],
30            ModelFormat::ONNX => &["onnx"],
31            ModelFormat::TensorRT => &["trt", "engine", "plan"],
32            ModelFormat::Safetensors => &["safetensors"],
33            ModelFormat::CoreML => &["mlmodel", "mlpackage"],
34        }
35    }
36
37    /// Get human-readable name
38    pub fn name(&self) -> &str {
39        match self {
40            ModelFormat::GGUF => "GGUF (llama.cpp)",
41            ModelFormat::GGML => "GGML (llama.cpp legacy)",
42            ModelFormat::ONNX => "ONNX Runtime",
43            ModelFormat::TensorRT => "TensorRT",
44            ModelFormat::Safetensors => "Safetensors",
45            ModelFormat::CoreML => "CoreML",
46        }
47    }
48}
49
50/// Runtime configuration for model initialization
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct RuntimeConfig {
53    /// Path to model file
54    pub model_path: PathBuf,
55    /// Model format
56    pub format: ModelFormat,
57    /// Host for runtime server (e.g., "127.0.0.1")
58    pub host: String,
59    /// Port for runtime server (e.g., 8001)
60    pub port: u16,
61    /// Context size
62    pub context_size: u32,
63    /// Batch size
64    pub batch_size: u32,
65    /// Number of CPU threads
66    pub threads: u32,
67    /// GPU layers to offload (0 = CPU only)
68    pub gpu_layers: u32,
69    /// Number of parallel KV-cache slots (continuous batching slots).
70    /// Should match MAX_CONCURRENT_STREAMS. Maps to llama-server --parallel N.
71    pub parallel_slots: u32,
72    /// Micro-batch size for GPU compute. Larger values increase tensor-core
73    /// utilisation. Maps to llama-server --ubatch-size N.
74    pub ubatch_size: u32,
75    /// Path to runtime binary (e.g., llama-server.exe)
76    pub runtime_binary: Option<PathBuf>,
77    /// Path to draft model for speculative decoding. None = disabled.
78    /// Maps to llama-server --model-draft.
79    pub draft_model_path: Option<PathBuf>,
80    /// Maximum draft tokens generated per speculative step.
81    /// Maps to llama-server --draft-max. Default: 8.
82    pub speculative_draft_max: u32,
83    /// Minimum acceptance probability for a draft token.
84    /// Maps to llama-server --draft-p-min. Default: 0.4.
85    pub speculative_draft_p_min: f32,
86    /// Additional runtime-specific configuration
87    pub extra_config: serde_json::Value,
88}
89
90impl Default for RuntimeConfig {
91    fn default() -> Self {
92        Self {
93            model_path: PathBuf::new(),
94            format: ModelFormat::GGUF,
95            host: "127.0.0.1".to_string(),
96            port: 8001,
97            context_size: 8192,
98            batch_size: 128,
99            threads: 6,
100            gpu_layers: 0,
101            parallel_slots: 1,
102            ubatch_size: 512,
103            runtime_binary: None,
104            draft_model_path: None,
105            speculative_draft_max: 8,
106            speculative_draft_p_min: 0.4,
107            extra_config: serde_json::json!({}),
108        }
109    }
110}
111
112/// Inference request (OpenAI-compatible format)
113#[derive(Debug, Clone, Serialize, Deserialize)]
114pub struct InferenceRequest {
115    pub messages: Vec<ChatMessage>,
116    #[serde(default = "default_max_tokens")]
117    pub max_tokens: u32,
118    #[serde(default = "default_temperature")]
119    pub temperature: f32,
120    #[serde(default = "default_stream")]
121    pub stream: bool,
122}
123
124fn default_max_tokens() -> u32 { 2000 }
125fn default_temperature() -> f32 { 0.7 }
126fn default_stream() -> bool { false }
127
128#[derive(Debug, Clone, Serialize, Deserialize)]
129pub struct ChatMessage {
130    pub role: String,
131    pub content: String,
132}
133
134/// Inference response
135#[derive(Debug, Clone, Serialize, Deserialize)]
136pub struct InferenceResponse {
137    pub content: String,
138    pub finish_reason: Option<String>,
139}
140
141/// Model runtime trait - all runtime adapters must implement this
142#[async_trait]
143pub trait ModelRuntime: Send + Sync {
144    /// Get the format this runtime supports
145    fn supported_format(&self) -> ModelFormat;
146
147    /// Initialize the runtime (start server process, load model, etc.)
148    async fn initialize(&mut self, config: RuntimeConfig) -> anyhow::Result<()>;
149
150    /// Check if runtime is ready for inference
151    async fn is_ready(&self) -> bool;
152
153    /// Get health status
154    async fn health_check(&self) -> anyhow::Result<String>;
155
156    /// Get the base URL for inference API (e.g., "http://127.0.0.1:8001")
157    fn base_url(&self) -> String;
158
159    /// Get the OpenAI-compatible chat completions endpoint
160    fn completions_url(&self) -> String {
161        format!("{}/v1/chat/completions", self.base_url())
162    }
163
164    /// Perform inference (non-streaming)
165    async fn generate(
166        &self,
167        request: InferenceRequest,
168    ) -> anyhow::Result<InferenceResponse>;
169
170    /// Perform streaming inference
171    async fn generate_stream(
172        &self,
173        request: InferenceRequest,
174    ) -> anyhow::Result<Box<dyn futures_util::Stream<Item = Result<String, anyhow::Error>> + Send + Unpin>>;
175
176    /// Shutdown the runtime (stop server, cleanup resources)
177    async fn shutdown(&mut self) -> anyhow::Result<()>;
178
179    /// Get runtime metadata
180    fn metadata(&self) -> RuntimeMetadata;
181}
182
183/// Runtime metadata
184#[derive(Debug, Clone, Serialize, Deserialize)]
185pub struct RuntimeMetadata {
186    pub format: ModelFormat,
187    pub runtime_name: String,
188    pub version: String,
189    pub supports_gpu: bool,
190    pub supports_streaming: bool,
191}