langextract_rust/inference.rs

//! Language model inference abstractions and implementations.
//!
//! This module provides the core abstraction for language model inference,
//! including the base trait that all providers must implement.
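//!
//! A sketch of a typical call site, assuming this module is exported
//! publicly as `langextract_rust::inference` (the provider behind the
//! trait object is whichever backend you have configured):
//!
//! ```no_run
//! # use langextract_rust::inference::{BaseLanguageModel, InferenceConfig};
//! # async fn demo(model: &dyn BaseLanguageModel) -> langextract_rust::exceptions::LangExtractResult<()> {
//! let kwargs = InferenceConfig::new().with_temperature(0.2).to_hashmap();
//! let outputs = model.infer_single("Summarize: ...", &kwargs).await?;
//! for scored in &outputs {
//!     println!("{}", scored);
//! }
//! # Ok(())
//! # }
//! ```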

use crate::{data::FormatType, exceptions::LangExtractResult, schema::BaseSchema};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::fmt;

/// A scored output from a language model
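///
/// # Examples
///
/// A minimal sketch (the path assumes this module is public as
/// `langextract_rust::inference`):
///
/// ```
/// use langextract_rust::inference::ScoredOutput;
///
/// let scored = ScoredOutput::new("hello".to_string(), Some(0.9));
/// assert_eq!(scored.text(), "hello");
/// assert!(scored.has_score());
///
/// let unscored = ScoredOutput::from_text("hello".to_string());
/// assert!(!unscored.has_score());
/// ```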
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ScoredOutput {
    /// Confidence score for this output (if available)
    pub score: Option<f32>,
    /// The generated text output
    pub output: Option<String>,
}

impl ScoredOutput {
    /// Create a new scored output
    pub fn new(output: String, score: Option<f32>) -> Self {
        Self {
            output: Some(output),
            score,
        }
    }

    /// Create a scored output with just text (no score)
    pub fn from_text(output: String) -> Self {
        Self {
            output: Some(output),
            score: None,
        }
    }

    /// Get the output text, returning an empty string if `None`
    pub fn text(&self) -> &str {
        self.output.as_deref().unwrap_or("")
    }

    /// Check if this output has a score
    pub fn has_score(&self) -> bool {
        self.score.is_some()
    }
}

impl fmt::Display for ScoredOutput {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let score_str = match self.score {
            Some(score) => format!("{:.2}", score),
            None => "-".to_string(),
        };

        match &self.output {
            Some(output) => {
                writeln!(f, "Score: {}", score_str)?;
                writeln!(f, "Output:")?;
                for line in output.lines() {
                    writeln!(f, "  {}", line)?;
                }
                Ok(())
            }
            None => write!(f, "Score: {}\nOutput: None", score_str),
        }
    }
}

/// Abstract base trait for language model inference
///
/// All language model providers must implement this trait to be compatible
/// with the langextract framework.
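///
/// # Examples
///
/// A minimal sketch of an implementor (the echo behavior stands in for a
/// real backend, and the crate paths assume this module is public):
///
/// ```no_run
/// use async_trait::async_trait;
/// use langextract_rust::exceptions::LangExtractResult;
/// use langextract_rust::inference::{BaseLanguageModel, ScoredOutput};
/// use std::collections::HashMap;
///
/// struct EchoModel;
///
/// #[async_trait]
/// impl BaseLanguageModel for EchoModel {
///     async fn infer(
///         &self,
///         batch_prompts: &[String],
///         _kwargs: &HashMap<String, serde_json::Value>,
///     ) -> LangExtractResult<Vec<Vec<ScoredOutput>>> {
///         // One inner batch per prompt; here each prompt is echoed back
///         // as a single unscored candidate.
///         Ok(batch_prompts
///             .iter()
///             .map(|p| vec![ScoredOutput::from_text(p.clone())])
///             .collect())
///     }
///
///     fn model_id(&self) -> &str {
///         "echo-1"
///     }
///
///     fn provider_name(&self) -> &str {
///         "echo"
///     }
/// }
/// ```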
#[async_trait]
pub trait BaseLanguageModel: Send + Sync {
    /// Get the schema class this provider supports
    fn get_schema_class(&self) -> Option<Box<dyn BaseSchema>> {
        None
    }

    /// Apply a schema instance to this provider
    fn apply_schema(&mut self, _schema: Option<Box<dyn BaseSchema>>) {
        // Default implementation does nothing
    }

    /// Set explicit fence output preference
    fn set_fence_output(&mut self, _fence_output: Option<bool>) {
        // Default implementation does nothing
    }

    /// Whether this model requires fence output for parsing
    fn requires_fence_output(&self) -> bool {
        true // Conservative default
    }

    /// Perform inference on a batch of prompts
    ///
    /// # Arguments
    ///
    /// * `batch_prompts` - Batch of input prompts for inference
    /// * `kwargs` - Additional inference parameters (temperature, max_tokens, etc.)
    ///
    /// # Returns
    ///
    /// A vector of batches, where each batch contains scored outputs for one prompt
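    ///
    /// For example, two prompts with one candidate each yield a result of
    /// shape `[[ScoredOutput], [ScoredOutput]]`: the outer index matches
    /// the prompt index, the inner index the candidate rank.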
    async fn infer(
        &self,
        batch_prompts: &[String],
        kwargs: &std::collections::HashMap<String, serde_json::Value>,
    ) -> LangExtractResult<Vec<Vec<ScoredOutput>>>;

    /// Convenience method for single prompt inference
    async fn infer_single(
        &self,
        prompt: &str,
        kwargs: &std::collections::HashMap<String, serde_json::Value>,
    ) -> LangExtractResult<Vec<ScoredOutput>> {
        let results = self.infer(&[prompt.to_string()], kwargs).await?;
        Ok(results.into_iter().next().unwrap_or_default())
    }

    /// Parse model output as JSON or YAML
    ///
    /// This expects raw JSON/YAML without code fences.
    /// Code fence extraction is handled by the resolver.
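    ///
    /// A sketch of the expected round trip, assuming this module is
    /// public as `langextract_rust::inference`:
    ///
    /// ```no_run
    /// # use langextract_rust::inference::BaseLanguageModel;
    /// # fn demo(model: &dyn BaseLanguageModel) {
    /// let parsed = model.parse_output(r#"{"name": "Ada"}"#).unwrap();
    /// assert_eq!(parsed["name"], "Ada");
    ///
    /// // YAML without fences is accepted as a fallback:
    /// let parsed = model.parse_output("name: Ada").unwrap();
    /// assert_eq!(parsed["name"], "Ada");
    /// # }
    /// ```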
    fn parse_output(&self, output: &str) -> LangExtractResult<serde_json::Value> {
        // Default implementation tries JSON first, then YAML
        match serde_json::from_str(output) {
            Ok(value) => Ok(value),
            Err(_) => {
                // Try YAML if JSON fails
                match serde_yaml::from_str::<serde_yaml::Value>(output) {
                    Ok(value) => {
                        // Convert YAML value to JSON value for consistency
                        let json_str = serde_json::to_string(&value)?;
                        Ok(serde_json::from_str(&json_str)?)
                    }
                    Err(e) => Err(crate::exceptions::LangExtractError::parsing(format!(
                        "Failed to parse output as JSON or YAML: {}",
                        e
                    ))),
                }
            }
        }
    }

    /// Get the format type this model uses
    fn format_type(&self) -> FormatType {
        FormatType::Json // Default to JSON
    }

    /// Get the model ID/name
    fn model_id(&self) -> &str;

    /// Get the provider name
    fn provider_name(&self) -> &str;

    /// Get supported model IDs for this provider
    fn supported_models() -> Vec<&'static str>
    where
        Self: Sized,
    {
        vec![]
    }

    /// Check if this provider supports a given model ID
    fn supports_model(model_id: &str) -> bool
    where
        Self: Sized,
    {
        Self::supported_models()
            .iter()
            .any(|&supported| model_id.contains(supported))
    }
}

/// Error type for inference operations that don't produce any outputs
#[derive(Debug, thiserror::Error)]
#[error("No scored outputs available from the language model: {message}")]
pub struct InferenceOutputError {
    pub message: String,
}

impl InferenceOutputError {
    pub fn new(message: String) -> Self {
        Self { message }
    }
}

/// Inference configuration parameters
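///
/// Built with chained setters and converted to a parameter map for
/// [`BaseLanguageModel::infer`] (the path assumes this module is public):
///
/// ```
/// use langextract_rust::inference::InferenceConfig;
///
/// let config = InferenceConfig::new()
///     .with_temperature(0.7)
///     .with_max_tokens(256)
///     .with_stop_sequence("END".to_string());
///
/// let kwargs = config.to_hashmap();
/// assert_eq!(kwargs.get("max_tokens"), Some(&serde_json::json!(256)));
/// ```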
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceConfig {
    /// Sampling temperature (0.0 to 1.0)
    pub temperature: f32,
    /// Maximum number of tokens to generate
    pub max_tokens: Option<usize>,
    /// Number of candidate outputs to generate
    pub num_candidates: usize,
    /// Stop sequences to halt generation
    pub stop_sequences: Vec<String>,
    /// Additional provider-specific parameters
    pub extra_params: std::collections::HashMap<String, serde_json::Value>,
}

impl Default for InferenceConfig {
    fn default() -> Self {
        Self {
            temperature: 0.5,
            max_tokens: None,
            num_candidates: 1,
            stop_sequences: vec![],
            extra_params: std::collections::HashMap::new(),
        }
    }
}

impl InferenceConfig {
    /// Create a new inference config with default values
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the temperature (clamped to the 0.0..=1.0 range)
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = temperature.clamp(0.0, 1.0);
        self
    }

    /// Set the maximum number of tokens
    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Set the number of candidate outputs (minimum 1)
    pub fn with_num_candidates(mut self, num_candidates: usize) -> Self {
        self.num_candidates = num_candidates.max(1);
        self
    }

    /// Add a stop sequence
    pub fn with_stop_sequence(mut self, stop_sequence: String) -> Self {
        self.stop_sequences.push(stop_sequence);
        self
    }

    /// Add an extra parameter
    pub fn with_extra_param(mut self, key: String, value: serde_json::Value) -> Self {
        self.extra_params.insert(key, value);
        self
    }

    /// Convert to a HashMap for passing to inference methods
    pub fn to_hashmap(&self) -> std::collections::HashMap<String, serde_json::Value> {
        let mut map = std::collections::HashMap::new();
        map.insert("temperature".to_string(), serde_json::json!(self.temperature));

        if let Some(max_tokens) = self.max_tokens {
            map.insert("max_tokens".to_string(), serde_json::json!(max_tokens));
        }

        map.insert("num_candidates".to_string(), serde_json::json!(self.num_candidates));

        if !self.stop_sequences.is_empty() {
            map.insert("stop_sequences".to_string(), serde_json::json!(self.stop_sequences));
        }

        // Add extra parameters
        for (key, value) in &self.extra_params {
            map.insert(key.clone(), value.clone());
        }

        map
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scored_output_creation() {
        let output = ScoredOutput::new("Hello world".to_string(), Some(0.9));
        assert_eq!(output.text(), "Hello world");
        assert!(output.has_score());
        assert_eq!(output.score, Some(0.9));

        let output_no_score = ScoredOutput::from_text("Hello world".to_string());
        assert_eq!(output_no_score.text(), "Hello world");
        assert!(!output_no_score.has_score());
    }

    #[test]
    fn test_scored_output_display() {
        let output = ScoredOutput::new("Hello\nworld".to_string(), Some(0.85));
        let display = format!("{}", output);
        assert!(display.contains("Score: 0.85"));
        assert!(display.contains("  Hello"));
        assert!(display.contains("  world"));

        let output_no_score = ScoredOutput::from_text("Test".to_string());
        let display = format!("{}", output_no_score);
        assert!(display.contains("Score: -"));
    }

    #[test]
    fn test_inference_config() {
        let config = InferenceConfig::new()
            .with_temperature(0.7)
            .with_max_tokens(100)
            .with_num_candidates(3)
            .with_stop_sequence("END".to_string())
            .with_extra_param("custom_param".to_string(), serde_json::json!("value"));

        assert_eq!(config.temperature, 0.7);
        assert_eq!(config.max_tokens, Some(100));
        assert_eq!(config.num_candidates, 3);
        assert_eq!(config.stop_sequences, vec!["END"]);

        let hashmap = config.to_hashmap();
        assert_eq!(hashmap.get("temperature"), Some(&serde_json::json!(0.7f32)));
        assert_eq!(hashmap.get("max_tokens"), Some(&serde_json::json!(100)));
        assert_eq!(hashmap.get("custom_param"), Some(&serde_json::json!("value")));
    }

    #[test]
    fn test_temperature_clamping() {
        let config = InferenceConfig::new().with_temperature(1.5);
        assert_eq!(config.temperature, 1.0);

        let config = InferenceConfig::new().with_temperature(-0.5);
        assert_eq!(config.temperature, 0.0);
    }

    #[test]
    fn test_serialization() {
        let output = ScoredOutput::new("test".to_string(), Some(0.5));
        let json = serde_json::to_string(&output).unwrap();
        let deserialized: ScoredOutput = serde_json::from_str(&json).unwrap();
        assert_eq!(output, deserialized);

        let config = InferenceConfig::new().with_temperature(0.8);
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: InferenceConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(config.temperature, deserialized.temperature);
    }
}