Skip to main content

spn_native/inference/
runtime.rs

1//! Native runtime implementation using mistral.rs.
2//!
3//! This module provides the `NativeRuntime` struct which implements
4//! the `InferenceBackend` trait using the mistral.rs library.
5
6use crate::inference::traits::InferenceBackend;
7use crate::NativeError;
8use futures_util::stream::Stream;
9use spn_core::{ChatOptions, ChatResponse, LoadConfig, ModelInfo};
10use std::path::PathBuf;
11
12#[cfg(feature = "inference")]
13use spn_core::ChatRole;
14#[cfg(feature = "inference")]
15use std::path::Path;
16#[cfg(feature = "inference")]
17use std::sync::Arc;
18#[cfg(feature = "inference")]
19use tracing::{debug, info};
20
21#[cfg(feature = "inference")]
22use mistralrs::{
23    GgufModelBuilder, MemoryGpuConfig, Model, PagedAttentionMetaBuilder, RequestBuilder,
24    TextMessageRole, TextMessages,
25};
26#[cfg(feature = "inference")]
27use tokio::sync::RwLock;
28
29/// Native runtime for local LLM inference.
30///
31/// Uses mistral.rs for high-performance inference on GGUF models.
32/// Supports CPU and GPU (Metal on macOS, CUDA on Linux) acceleration.
33///
34/// # Example
35///
36/// ```ignore
37/// use spn_native::inference::NativeRuntime;
38/// use spn_core::LoadConfig;
39///
40/// let mut runtime = NativeRuntime::new()?;
41/// runtime.load("model.gguf".into(), LoadConfig::default()).await?;
42/// let response = runtime.infer("Hello!", Default::default()).await?;
43/// ```
44#[allow(dead_code)] // Fields used only with inference feature
45pub struct NativeRuntime {
46    /// The loaded model (None if no model is loaded).
47    #[cfg(feature = "inference")]
48    model: Option<Arc<RwLock<Model>>>,
49
50    /// Metadata about the loaded model.
51    model_info: Option<ModelInfo>,
52
53    /// Path to the currently loaded model.
54    model_path: Option<PathBuf>,
55
56    /// Load configuration used for the current model.
57    config: Option<LoadConfig>,
58}
59
60impl NativeRuntime {
61    /// Create a new native runtime.
62    ///
63    /// The runtime is created without a model loaded. Call `load()` to
64    /// load a model before running inference.
65    #[must_use]
66    pub fn new() -> Self {
67        Self {
68            #[cfg(feature = "inference")]
69            model: None,
70            model_info: None,
71            model_path: None,
72            config: None,
73        }
74    }
75
76    /// Get the path to the currently loaded model.
77    #[must_use]
78    pub fn model_path(&self) -> Option<&PathBuf> {
79        self.model_path.as_ref()
80    }
81
82    /// Get the load configuration for the current model.
83    #[must_use]
84    pub fn config(&self) -> Option<&LoadConfig> {
85        self.config.as_ref()
86    }
87
88    /// Convert spn-core ChatRole to mistral.rs TextMessageRole.
89    #[cfg(feature = "inference")]
90    #[allow(dead_code)] // Will be used for streaming support
91    fn convert_role(role: ChatRole) -> TextMessageRole {
92        match role {
93            ChatRole::System => TextMessageRole::System,
94            ChatRole::User => TextMessageRole::User,
95            ChatRole::Assistant => TextMessageRole::Assistant,
96        }
97    }
98}
99
100impl Default for NativeRuntime {
101    fn default() -> Self {
102        Self::new()
103    }
104}
105
#[cfg(feature = "inference")]
impl InferenceBackend for NativeRuntime {
    /// Load a GGUF model from `model_path`, replacing any model already loaded.
    ///
    /// The new path is validated *before* the current model is released, so a
    /// missing file, a non-file path, or a path without a file name leaves the
    /// existing model in place.
    ///
    /// # Errors
    ///
    /// - [`NativeError::ModelNotFound`] if `model_path` cannot be stat'ed.
    /// - [`NativeError::InvalidConfig`] if the path is not a regular file, has no
    ///   file name, or mistral.rs fails to configure PagedAttention or build the
    ///   model. Build failures happen after the previous model was unloaded, so
    ///   the runtime is then left with no model.
    async fn load(&mut self, model_path: PathBuf, config: LoadConfig) -> Result<(), NativeError> {
        info!(?model_path, "Loading GGUF model");

        // `Path::exists` is `fs::metadata(..).is_ok()`, so this fails in exactly
        // the same cases, but without blocking the async executor. The metadata
        // is reused below for `ModelInfo::size`.
        let metadata = tokio::fs::metadata(&model_path).await.map_err(|_| {
            NativeError::ModelNotFound {
                repo: "local".to_string(),
                filename: model_path.to_string_lossy().to_string(),
            }
        })?;
        if !metadata.is_file() {
            return Err(NativeError::InvalidConfig(format!(
                "Invalid model path: {} is not a file",
                model_path.display()
            )));
        }

        // API: GgufModelBuilder::new(directory, vec![filename])
        let (parent, filename) = split_model_path(&model_path).ok_or_else(|| {
            NativeError::InvalidConfig("Invalid model path: no filename".to_string())
        })?;

        // Only release the current model once the new path is known to be usable.
        if self.model.is_some() {
            self.unload().await?;
        }

        // NOTE(review): `config.gpu_layers` is only logged here; it is not forwarded
        // to `GgufModelBuilder`. Confirm this is intended.
        debug!(gpu_layers = config.gpu_layers, %parent, %filename, "Building model");

        // Build model with PagedAttention for better memory management.
        // PagedAttention enables efficient KV cache handling for longer contexts.
        // Use context_size from LoadConfig, defaulting to 2048 if not specified.
        let context_size = config.context_size.unwrap_or(2048);
        let model = GgufModelBuilder::new(parent, vec![filename])
            .with_logging()
            .with_paged_attn(|| {
                PagedAttentionMetaBuilder::default()
                    .with_block_size(32)
                    .with_gpu_memory(MemoryGpuConfig::ContextSize(context_size as usize))
                    .build()
            })
            .map_err(|e| NativeError::InvalidConfig(format!("PagedAttention config error: {e}")))?
            .build()
            .await
            .map_err(|e| NativeError::InvalidConfig(format!("Failed to build model: {e}")))?;

        let info = ModelInfo {
            name: model_path
                .file_stem()
                .map(|s| s.to_string_lossy().to_string())
                .unwrap_or_else(|| "unknown".to_string()),
            size: metadata.len(),
            quantization: extract_quantization_from_path(&model_path),
            parameters: None,
            digest: None,
        };

        self.model = Some(Arc::new(RwLock::new(model)));
        self.model_info = Some(info);
        self.model_path = Some(model_path);
        self.config = Some(config);

        info!("Model loaded successfully");
        Ok(())
    }

    /// Drop the loaded model and all associated state; a no-op when empty.
    async fn unload(&mut self) -> Result<(), NativeError> {
        if self.model.is_some() {
            info!("Unloading model");
            self.model = None;
            self.model_info = None;
            self.model_path = None;
            self.config = None;
        }
        Ok(())
    }

    fn is_loaded(&self) -> bool {
        self.model.is_some()
    }

    fn model_info(&self) -> Option<&ModelInfo> {
        self.model_info.as_ref()
    }

    /// Run a single, non-streaming chat completion for `prompt`.
    ///
    /// Only `temperature` and `max_tokens` from `options` are applied. Token
    /// counts that do not fit in `u32` are reported as `None` instead of being
    /// silently truncated.
    ///
    /// # Errors
    ///
    /// - [`NativeError::ModelNotLoaded`] if `load()` has not succeeded.
    /// - [`NativeError::InvalidConfig`] if the mistral.rs request fails or the
    ///   response carries no message content.
    async fn infer(&self, prompt: &str, options: ChatOptions) -> Result<ChatResponse, NativeError> {
        let model = self.model.as_ref().ok_or(NativeError::ModelNotLoaded)?;

        let model = model.read().await;

        // Build messages - just user prompt for now
        // System messages should be passed as part of the prompt or via messages API
        let messages = TextMessages::new().add_message(TextMessageRole::User, prompt);

        debug!(
            temperature = options.temperature,
            max_tokens = options.max_tokens,
            "Running inference"
        );

        let mut request = RequestBuilder::from(messages);

        // Apply temperature if provided (convert f32 to f64)
        if let Some(temp) = options.temperature {
            request = request.set_sampler_temperature(f64::from(temp));
        }

        if let Some(max_tokens) = options.max_tokens {
            request = request.set_sampler_max_len(max_tokens as usize);
        }

        let response = model
            .send_chat_request(request)
            .await
            .map_err(|e| NativeError::InvalidConfig(format!("Inference failed: {e}")))?;

        // Fails both when there are no choices and when the first choice has no content.
        let content = response
            .choices
            .first()
            .and_then(|c| c.message.content.clone())
            .ok_or_else(|| {
                NativeError::InvalidConfig(
                    "Model returned empty response (no choices or no content)".to_string(),
                )
            })?;

        // Log performance metrics for debugging and optimization
        debug!(
            prompt_tokens = response.usage.prompt_tokens,
            completion_tokens = response.usage.completion_tokens,
            avg_prompt_tok_per_sec = ?response.usage.avg_prompt_tok_per_sec,
            avg_compl_tok_per_sec = ?response.usage.avg_compl_tok_per_sec,
            "Inference completed"
        );

        Ok(ChatResponse {
            message: spn_core::ChatMessage {
                role: ChatRole::Assistant,
                content,
            },
            done: true,
            total_duration: None,
            prompt_eval_count: u32::try_from(response.usage.prompt_tokens).ok(),
            eval_count: u32::try_from(response.usage.completion_tokens).ok(),
        })
    }

    async fn infer_stream(
        &self,
        _prompt: &str,
        _options: ChatOptions,
    ) -> Result<impl Stream<Item = Result<String, NativeError>> + Send, NativeError> {
        // Streaming requires complex lifetime management with the model lock.
        // For now, use the non-streaming `infer` method instead.
        // TODO: Implement streaming by cloning the Arc and managing lifetimes properly.
        Err::<futures_util::stream::Empty<Result<String, NativeError>>, _>(
            NativeError::InvalidConfig(
                "Streaming not yet implemented for native runtime. Use infer() instead.".to_string(),
            ),
        )
    }
}

/// Split a model path into the `(directory, filename)` pair passed to
/// `GgufModelBuilder::new`.
///
/// `Path::parent` yields `Some("")` for a bare file name such as `model.gguf`;
/// that empty directory is mapped to `"."`. Returns `None` when the path has no
/// final file-name component (e.g. `/` or a path ending in `..`).
#[cfg(feature = "inference")]
fn split_model_path(model_path: &Path) -> Option<(String, String)> {
    let filename = model_path.file_name()?.to_string_lossy().into_owned();
    let directory = model_path
        .parent()
        .filter(|p| !p.as_os_str().is_empty())
        .map_or_else(|| ".".to_string(), |p| p.to_string_lossy().into_owned());
    Some((directory, filename))
}
278
/// Pull the quantization tag, if any, out of a model file's name.
///
/// Only the final path component is inspected; the parsing itself is done by
/// [`crate::extract_quantization`]. Returns `None` when the path has no file
/// name or when that helper returns `None`.
#[cfg(feature = "inference")]
fn extract_quantization_from_path(path: &Path) -> Option<String> {
    path.file_name()
        .map(|name| name.to_string_lossy())
        .and_then(|name| crate::extract_quantization(&name))
}
287
288// Stub implementation when inference feature is not enabled
289#[cfg(not(feature = "inference"))]
290impl InferenceBackend for NativeRuntime {
291    async fn load(&mut self, _model_path: PathBuf, _config: LoadConfig) -> Result<(), NativeError> {
292        Err(NativeError::InvalidConfig(
293            "Inference feature not enabled. Rebuild with --features inference".to_string(),
294        ))
295    }
296
297    async fn unload(&mut self) -> Result<(), NativeError> {
298        Ok(())
299    }
300
301    fn is_loaded(&self) -> bool {
302        false
303    }
304
305    fn model_info(&self) -> Option<&ModelInfo> {
306        None
307    }
308
309    async fn infer(
310        &self,
311        _prompt: &str,
312        _options: ChatOptions,
313    ) -> Result<ChatResponse, NativeError> {
314        Err(NativeError::InvalidConfig(
315            "Inference feature not enabled. Rebuild with --features inference".to_string(),
316        ))
317    }
318
319    async fn infer_stream(
320        &self,
321        _prompt: &str,
322        _options: ChatOptions,
323    ) -> Result<impl Stream<Item = Result<String, NativeError>> + Send, NativeError> {
324        Err::<futures_util::stream::Empty<Result<String, NativeError>>, _>(
325            NativeError::InvalidConfig(
326                "Inference feature not enabled. Rebuild with --features inference".to_string(),
327            ),
328        )
329    }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_runtime_creation() {
        let runtime = NativeRuntime::new();
        assert!(!runtime.is_loaded());
        assert!(runtime.model_info().is_none());
        assert!(runtime.model_path().is_none());
        assert!(runtime.config().is_none());
    }

    #[test]
    fn test_runtime_default() {
        let runtime = NativeRuntime::default();
        assert!(!runtime.is_loaded());
    }

    /// Unloading an empty runtime succeeds and leaves it empty, in both builds.
    #[tokio::test]
    async fn test_unload_without_model() {
        let mut runtime = NativeRuntime::new();
        assert!(runtime.unload().await.is_ok());
        assert!(!runtime.is_loaded());
        assert!(runtime.model_path().is_none());
    }

    /// Streaming is rejected in both builds (unimplemented / feature disabled).
    #[tokio::test]
    async fn test_infer_stream_unsupported() {
        let runtime = NativeRuntime::new();
        let result = runtime.infer_stream("test", ChatOptions::default()).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    #[cfg(not(feature = "inference"))]
    async fn test_load_without_feature() {
        let mut runtime = NativeRuntime::new();
        let result = runtime
            .load(PathBuf::from("test.gguf"), LoadConfig::default())
            .await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Inference feature not enabled"));
    }

    #[tokio::test]
    #[cfg(not(feature = "inference"))]
    async fn test_infer_without_feature() {
        let runtime = NativeRuntime::new();
        let result = runtime.infer("test", ChatOptions::default()).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    #[cfg(feature = "inference")]
    async fn test_infer_without_model() {
        let runtime = NativeRuntime::new();
        let result = runtime.infer("test", ChatOptions::default()).await;
        assert!(matches!(result, Err(NativeError::ModelNotLoaded)));
    }

    #[tokio::test]
    #[cfg(feature = "inference")]
    async fn test_load_missing_file() {
        let mut runtime = NativeRuntime::new();
        let result = runtime
            .load(
                PathBuf::from("definitely-missing-model.gguf"),
                LoadConfig::default(),
            )
            .await;
        assert!(matches!(result, Err(NativeError::ModelNotFound { .. })));
        assert!(!runtime.is_loaded());
    }
}