// spn_native/inference/traits.rs

1//! Inference backend traits.
2//!
3//! Defines the interface for local model inference backends.
4
5use futures_util::Stream;
6use spn_core::{ChatOptions, ChatResponse, LoadConfig, ModelInfo};
7use std::future::Future;
8use std::path::PathBuf;
9use std::pin::Pin;
10
11use crate::NativeError;
12
/// Trait for any inference backend (mistral.rs, llama.cpp, etc.).
///
/// This trait provides a unified interface for loading and running
/// local LLM inference. Implementations can use different backends
/// while presenting the same API to consumers.
///
/// Methods return `impl Future` (return-position impl Trait in trait),
/// so this trait is not object-safe; use [`DynInferenceBackend`] when
/// you need `dyn` dispatch.
pub trait InferenceBackend: Send + Sync {
    /// Load a model from disk.
    ///
    /// # Arguments
    /// * `model_path` - Path to the GGUF model file
    /// * `config` - Load configuration (context size, GPU layers, etc.)
    ///
    /// # Returns
    /// `Ok(())` if the model was loaded successfully.
    ///
    /// # Errors
    /// Returns a [`NativeError`] if loading fails (implementation-defined;
    /// e.g. unreadable file or invalid configuration).
    fn load(
        &mut self,
        model_path: PathBuf,
        config: LoadConfig,
    ) -> impl Future<Output = Result<(), NativeError>> + Send;

    /// Unload the model from memory.
    ///
    /// Frees GPU/CPU memory used by the model.
    ///
    /// # Errors
    /// Returns a [`NativeError`] if the backend fails to release the model.
    fn unload(&mut self) -> impl Future<Output = Result<(), NativeError>> + Send;

    /// Check if a model is currently loaded.
    #[must_use]
    fn is_loaded(&self) -> bool;

    /// Get metadata about the loaded model.
    ///
    /// Returns `None` if no model is loaded.
    fn model_info(&self) -> Option<&ModelInfo>;

    /// Generate a response (non-streaming).
    ///
    /// # Arguments
    /// * `prompt` - The input prompt
    /// * `options` - Generation options (temperature, max_tokens, etc.)
    ///
    /// # Returns
    /// The complete chat response.
    ///
    /// # Errors
    /// Returns a [`NativeError`] if generation fails.
    fn infer(
        &self,
        prompt: &str,
        options: ChatOptions,
    ) -> impl Future<Output = Result<ChatResponse, NativeError>> + Send;

    /// Generate a response (streaming).
    ///
    /// Returns a stream of token strings as they are generated.
    /// Each stream item is itself a `Result`, so token-level failures can
    /// surface mid-stream.
    ///
    /// # Arguments
    /// * `prompt` - The input prompt
    /// * `options` - Generation options (temperature, max_tokens, etc.)
    ///
    /// # Errors
    /// The outer `Result` covers failure to start generation; per-token
    /// errors are delivered through the stream items.
    fn infer_stream(
        &self,
        prompt: &str,
        options: ChatOptions,
    ) -> impl Future<Output = Result<impl Stream<Item = Result<String, NativeError>> + Send, NativeError>>
           + Send;
}
75
76/// Object-safe version of InferenceBackend for dynamic dispatch.
77///
78/// Use this when you need runtime polymorphism (e.g., `Box<dyn DynInferenceBackend>`).
79///
80/// Note: This trait takes owned `String` instead of `&str` for prompts
81/// to enable object-safe async methods.
82#[allow(clippy::type_complexity)]
83pub trait DynInferenceBackend: Send + Sync {
84    /// Load a model from disk (boxed future for object safety).
85    fn load_dyn(
86        &mut self,
87        model_path: PathBuf,
88        config: LoadConfig,
89    ) -> Pin<Box<dyn Future<Output = Result<(), NativeError>> + Send + '_>>;
90
91    /// Unload the model from memory (boxed future for object safety).
92    fn unload_dyn(&mut self) -> Pin<Box<dyn Future<Output = Result<(), NativeError>> + Send + '_>>;
93
94    /// Check if a model is currently loaded.
95    fn is_loaded_dyn(&self) -> bool;
96
97    /// Get metadata about the loaded model (cloned for object safety).
98    fn model_info_dyn(&self) -> Option<ModelInfo>;
99
100    /// Generate a response (boxed future for object safety).
101    ///
102    /// Takes owned `String` instead of `&str` for object safety.
103    fn infer_dyn(
104        &self,
105        prompt: String,
106        options: ChatOptions,
107    ) -> Pin<Box<dyn Future<Output = Result<ChatResponse, NativeError>> + Send + '_>>;
108
109    /// Generate a streaming response (boxed stream for object safety).
110    ///
111    /// Takes owned `String` instead of `&str` for object safety.
112    fn infer_stream_dyn(
113        &self,
114        prompt: String,
115        options: ChatOptions,
116    ) -> Pin<
117        Box<
118            dyn Future<
119                    Output = Result<
120                        Pin<Box<dyn Stream<Item = Result<String, NativeError>> + Send + 'static>>,
121                        NativeError,
122                    >,
123                > + Send
124                + '_,
125        >,
126    >;
127}
128
129/// Blanket implementation of DynInferenceBackend for any InferenceBackend.
130impl<T: InferenceBackend + 'static> DynInferenceBackend for T {
131    fn load_dyn(
132        &mut self,
133        model_path: PathBuf,
134        config: LoadConfig,
135    ) -> Pin<Box<dyn Future<Output = Result<(), NativeError>> + Send + '_>> {
136        Box::pin(self.load(model_path, config))
137    }
138
139    fn unload_dyn(&mut self) -> Pin<Box<dyn Future<Output = Result<(), NativeError>> + Send + '_>> {
140        Box::pin(self.unload())
141    }
142
143    fn is_loaded_dyn(&self) -> bool {
144        InferenceBackend::is_loaded(self)
145    }
146
147    fn model_info_dyn(&self) -> Option<ModelInfo> {
148        InferenceBackend::model_info(self).cloned()
149    }
150
151    fn infer_dyn(
152        &self,
153        prompt: String,
154        options: ChatOptions,
155    ) -> Pin<Box<dyn Future<Output = Result<ChatResponse, NativeError>> + Send + '_>> {
156        Box::pin(async move { self.infer(&prompt, options).await })
157    }
158
159    fn infer_stream_dyn(
160        &self,
161        _prompt: String,
162        _options: ChatOptions,
163    ) -> Pin<
164        Box<
165            dyn Future<
166                    Output = Result<
167                        Pin<Box<dyn Stream<Item = Result<String, NativeError>> + Send + 'static>>,
168                        NativeError,
169                    >,
170                > + Send
171                + '_,
172        >,
173    > {
174        Box::pin(async move {
175            // We cannot easily box a stream that borrows from self,
176            // so for streaming, callers should use InferenceBackend directly
177            // or collect results into a Vec first
178            Err(NativeError::InvalidConfig(
179                "Streaming not supported via DynInferenceBackend. Use InferenceBackend directly."
180                    .to_string(),
181            ))
182        })
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time proof that `DynInferenceBackend` stays object-safe:
    // this item only compiles while `dyn DynInferenceBackend` is a valid type.
    const _OBJECT_SAFETY_CHECK: fn(&dyn DynInferenceBackend) = |_backend| {};
}