1use crate::{data::FormatType, exceptions::LangExtractResult, schema::BaseSchema};
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use std::fmt;
10
/// A single candidate produced by a language model, optionally paired
/// with a confidence/quality score.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ScoredOutput {
    /// Confidence/quality score for this output; `None` when the model
    /// does not report scores.
    pub score: Option<f32>,
    /// The generated text; `None` when the model produced no output.
    pub output: Option<String>,
}
19
20impl ScoredOutput {
21 pub fn new(output: String, score: Option<f32>) -> Self {
23 Self {
24 output: Some(output),
25 score,
26 }
27 }
28
29 pub fn from_text(output: String) -> Self {
31 Self {
32 output: Some(output),
33 score: None,
34 }
35 }
36
37 pub fn text(&self) -> &str {
39 self.output.as_deref().unwrap_or("")
40 }
41
42 pub fn has_score(&self) -> bool {
44 self.score.is_some()
45 }
46}
47
48impl fmt::Display for ScoredOutput {
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 let score_str = match self.score {
51 Some(score) => format!("{:.2}", score),
52 None => "-".to_string(),
53 };
54
55 match &self.output {
56 Some(output) => {
57 writeln!(f, "Score: {}", score_str)?;
58 writeln!(f, "Output:")?;
59 for line in output.lines() {
60 writeln!(f, " {}", line)?;
61 }
62 Ok(())
63 }
64 None => write!(f, "Score: {}\nOutput: None", score_str),
65 }
66 }
67}
68
69#[async_trait]
74pub trait BaseLanguageModel: Send + Sync {
75 fn get_schema_class(&self) -> Option<Box<dyn BaseSchema>> {
77 None
78 }
79
80 fn apply_schema(&mut self, _schema: Option<Box<dyn BaseSchema>>) {
82 }
84
85 fn set_fence_output(&mut self, _fence_output: Option<bool>) {
87 }
89
90 fn requires_fence_output(&self) -> bool {
92 true }
94
95 async fn infer(
106 &self,
107 batch_prompts: &[String],
108 kwargs: &std::collections::HashMap<String, serde_json::Value>,
109 ) -> LangExtractResult<Vec<Vec<ScoredOutput>>>;
110
111 async fn infer_single(
113 &self,
114 prompt: &str,
115 kwargs: &std::collections::HashMap<String, serde_json::Value>,
116 ) -> LangExtractResult<Vec<ScoredOutput>> {
117 let results = self.infer(&[prompt.to_string()], kwargs).await?;
118 Ok(results.into_iter().next().unwrap_or_default())
119 }
120
121 fn parse_output(&self, output: &str) -> LangExtractResult<serde_json::Value> {
126 match serde_json::from_str(output) {
128 Ok(value) => Ok(value),
129 Err(_) => {
130 match serde_yaml::from_str::<serde_yaml::Value>(output) {
132 Ok(value) => {
133 let json_str = serde_json::to_string(&value)?;
135 Ok(serde_json::from_str(&json_str)?)
136 }
137 Err(e) => Err(crate::exceptions::LangExtractError::parsing(format!(
138 "Failed to parse output as JSON or YAML: {}",
139 e
140 ))),
141 }
142 }
143 }
144 }
145
146 fn format_type(&self) -> FormatType {
148 FormatType::Json }
150
151 fn model_id(&self) -> &str;
153
154 fn provider_name(&self) -> &str;
156
157 fn supported_models() -> Vec<&'static str>
159 where
160 Self: Sized,
161 {
162 vec![]
163 }
164
165 fn supports_model(model_id: &str) -> bool
167 where
168 Self: Sized,
169 {
170 Self::supported_models()
171 .iter()
172 .any(|&supported| model_id.contains(supported))
173 }
174}
175
/// Error raised when a language model produced no scored outputs.
#[derive(Debug, thiserror::Error)]
#[error("No scored outputs available from the language model: {message}")]
pub struct InferenceOutputError {
    /// Human-readable detail appended to the error display text.
    pub message: String,
}

impl InferenceOutputError {
    /// Creates the error with the given detail message.
    pub fn new(message: String) -> Self {
        Self { message }
    }
}
188
/// Generation parameters passed to a language model, convertible to a
/// flat parameter map via `to_hashmap`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceConfig {
    /// Sampling temperature; builder clamps this into [0.0, 1.0].
    pub temperature: f32,
    /// Optional cap on generated tokens; `None` means provider default.
    pub max_tokens: Option<usize>,
    /// Number of candidates to request per prompt (at least 1).
    pub num_candidates: usize,
    /// Sequences that stop generation when emitted.
    pub stop_sequences: Vec<String>,
    /// Provider-specific parameters merged into the request as-is.
    pub extra_params: std::collections::HashMap<String, serde_json::Value>,
}
203
204impl Default for InferenceConfig {
205 fn default() -> Self {
206 Self {
207 temperature: 0.5,
208 max_tokens: None,
209 num_candidates: 1,
210 stop_sequences: vec![],
211 extra_params: std::collections::HashMap::new(),
212 }
213 }
214}
215
216impl InferenceConfig {
217 pub fn new() -> Self {
219 Self::default()
220 }
221
222 pub fn with_temperature(mut self, temperature: f32) -> Self {
224 self.temperature = temperature.clamp(0.0, 1.0);
225 self
226 }
227
228 pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
230 self.max_tokens = Some(max_tokens);
231 self
232 }
233
234 pub fn with_num_candidates(mut self, num_candidates: usize) -> Self {
236 self.num_candidates = num_candidates.max(1);
237 self
238 }
239
240 pub fn with_stop_sequence(mut self, stop_sequence: String) -> Self {
242 self.stop_sequences.push(stop_sequence);
243 self
244 }
245
246 pub fn with_extra_param(mut self, key: String, value: serde_json::Value) -> Self {
248 self.extra_params.insert(key, value);
249 self
250 }
251
252 pub fn to_hashmap(&self) -> std::collections::HashMap<String, serde_json::Value> {
254 let mut map = std::collections::HashMap::new();
255 map.insert("temperature".to_string(), serde_json::json!(self.temperature));
256
257 if let Some(max_tokens) = self.max_tokens {
258 map.insert("max_tokens".to_string(), serde_json::json!(max_tokens));
259 }
260
261 map.insert("num_candidates".to_string(), serde_json::json!(self.num_candidates));
262
263 if !self.stop_sequences.is_empty() {
264 map.insert("stop_sequences".to_string(), serde_json::json!(self.stop_sequences));
265 }
266
267 for (key, value) in &self.extra_params {
269 map.insert(key.clone(), value.clone());
270 }
271
272 map
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    // Constructors populate text/score correctly, and `from_text`
    // leaves the score unset.
    #[test]
    fn test_scored_output_creation() {
        let output = ScoredOutput::new("Hello world".to_string(), Some(0.9));
        assert_eq!(output.text(), "Hello world");
        assert!(output.has_score());
        assert_eq!(output.score, Some(0.9));

        let output_no_score = ScoredOutput::from_text("Hello world".to_string());
        assert_eq!(output_no_score.text(), "Hello world");
        assert!(!output_no_score.has_score());
    }

    // Display renders a two-decimal score, indents each output line,
    // and shows "-" when no score is present.
    #[test]
    fn test_scored_output_display() {
        let output = ScoredOutput::new("Hello\nworld".to_string(), Some(0.85));
        let display = format!("{}", output);
        assert!(display.contains("Score: 0.85"));
        assert!(display.contains(" Hello"));
        assert!(display.contains(" world"));

        let output_no_score = ScoredOutput::from_text("Test".to_string());
        let display = format!("{}", output_no_score);
        assert!(display.contains("Score: -"));
    }

    // Builder methods chain and `to_hashmap` exposes both standard
    // keys and extra params.
    #[test]
    fn test_inference_config() {
        let config = InferenceConfig::new()
            .with_temperature(0.7)
            .with_max_tokens(100)
            .with_num_candidates(3)
            .with_stop_sequence("END".to_string())
            .with_extra_param("custom_param".to_string(), serde_json::json!("value"));

        assert_eq!(config.temperature, 0.7);
        assert_eq!(config.max_tokens, Some(100));
        assert_eq!(config.num_candidates, 3);
        assert_eq!(config.stop_sequences, vec!["END"]);

        let hashmap = config.to_hashmap();
        assert_eq!(hashmap.get("temperature"), Some(&serde_json::json!(0.7f32)));
        assert_eq!(hashmap.get("max_tokens"), Some(&serde_json::json!(100)));
        assert_eq!(hashmap.get("custom_param"), Some(&serde_json::json!("value")));
    }

    // Out-of-range temperatures are clamped to the [0.0, 1.0] bounds.
    #[test]
    fn test_temperature_clamping() {
        let config = InferenceConfig::new().with_temperature(1.5);
        assert_eq!(config.temperature, 1.0);

        let config = InferenceConfig::new().with_temperature(-0.5);
        assert_eq!(config.temperature, 0.0);
    }

    // ScoredOutput and InferenceConfig survive a serde JSON round-trip.
    #[test]
    fn test_serialization() {
        let output = ScoredOutput::new("test".to_string(), Some(0.5));
        let json = serde_json::to_string(&output).unwrap();
        let deserialized: ScoredOutput = serde_json::from_str(&json).unwrap();
        assert_eq!(output, deserialized);

        let config = InferenceConfig::new().with_temperature(0.8);
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: InferenceConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(config.temperature, deserialized.temperature);
    }
}