use crate::{data::FormatType, exceptions::LangExtractResult, schema::BaseSchema};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::fmt;
/// A single candidate produced by a language model, with an optional score.
///
/// Both fields are optional; `ScoredOutput::text` treats a missing output as
/// the empty string.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ScoredOutput {
/// Score attached to this output, if the provider supplied one.
/// NOTE(review): the scale/range is not established here — presumably provider-specific.
pub score: Option<f32>,
/// The generated text, if any.
pub output: Option<String>,
}
impl ScoredOutput {
    /// Builds an output that carries `output` together with an optional `score`.
    pub fn new(output: String, score: Option<f32>) -> Self {
        Self {
            score,
            output: Some(output),
        }
    }

    /// Builds an unscored output from its text alone.
    pub fn from_text(output: String) -> Self {
        Self::new(output, None)
    }

    /// Returns the output text, or `""` when no output is present.
    pub fn text(&self) -> &str {
        match &self.output {
            Some(body) => body.as_str(),
            None => "",
        }
    }

    /// Returns `true` when a score is attached.
    pub fn has_score(&self) -> bool {
        Option::is_some(&self.score)
    }
}
impl fmt::Display for ScoredOutput {
    /// Renders the score (two decimals, or `-` when absent) followed by the
    /// output. Each line of the output is written on its own line, prefixed
    /// with a space. A missing output is rendered inline as `Output: None`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let score_label = self
            .score
            .map_or_else(|| "-".to_string(), |score| format!("{:.2}", score));

        if let Some(body) = self.output.as_deref() {
            writeln!(f, "Score: {}", score_label)?;
            writeln!(f, "Output:")?;
            body.lines().try_for_each(|line| writeln!(f, " {}", line))
        } else {
            write!(f, "Score: {}\nOutput: None", score_label)
        }
    }
}
/// Common interface implemented by language-model providers.
///
/// Implementors must supply [`infer`](Self::infer), [`model_id`](Self::model_id)
/// and [`provider_name`](Self::provider_name); every other method has a default.
#[async_trait]
pub trait BaseLanguageModel: Send + Sync {
    /// Returns the schema this model uses for constrained output, if any.
    ///
    /// The default implementation returns `None`.
    fn get_schema_class(&self) -> Option<Box<dyn BaseSchema>> {
        None
    }

    /// Applies (or, with `None`, clears) a schema on this model.
    ///
    /// The default implementation ignores the argument.
    fn apply_schema(&mut self, _schema: Option<Box<dyn BaseSchema>>) {}

    /// Overrides the fenced-output setting.
    ///
    /// The default implementation ignores the argument.
    /// NOTE(review): `None` presumably means "use the model's default" — confirm with implementors.
    fn set_fence_output(&mut self, _fence_output: Option<bool>) {}

    /// Whether this model's raw output is expected to be wrapped in code fences.
    ///
    /// Defaults to `true`.
    fn requires_fence_output(&self) -> bool {
        true
    }

    /// Runs inference on a batch of prompts.
    ///
    /// `kwargs` carries provider-specific parameters. The result holds one
    /// `Vec<ScoredOutput>` per prompt.
    /// NOTE(review): results are presumably in the same order as `batch_prompts` — confirm per provider.
    async fn infer(
        &self,
        batch_prompts: &[String],
        kwargs: &std::collections::HashMap<String, serde_json::Value>,
    ) -> LangExtractResult<Vec<Vec<ScoredOutput>>>;

    /// Runs [`infer`](Self::infer) on a one-prompt batch and returns the first
    /// result.
    ///
    /// If `infer` returns no batches at all, this yields an empty `Vec` rather
    /// than an error.
    async fn infer_single(
        &self,
        prompt: &str,
        kwargs: &std::collections::HashMap<String, serde_json::Value>,
    ) -> LangExtractResult<Vec<ScoredOutput>> {
        let results = self.infer(&[prompt.to_string()], kwargs).await?;
        Ok(results.into_iter().next().unwrap_or_default())
    }

    /// Parses raw model output into a JSON value.
    ///
    /// The text is first parsed as JSON. If that fails, it is parsed as YAML and
    /// the resulting document is converted directly to `serde_json::Value`.
    ///
    /// NOTE(review): YAML accepts most free-form text as a plain string scalar,
    /// so non-structured output usually yields `Value::String` rather than an
    /// error — confirm that is the intended behavior.
    ///
    /// # Errors
    ///
    /// Returns a parsing error when the text is neither valid JSON nor valid
    /// YAML. Also propagates any `serde_json` error raised while converting the
    /// YAML value.
    fn parse_output(&self, output: &str) -> LangExtractResult<serde_json::Value> {
        if let Ok(value) = serde_json::from_str::<serde_json::Value>(output) {
            return Ok(value);
        }
        let yaml = serde_yaml::from_str::<serde_yaml::Value>(output).map_err(|e| {
            crate::exceptions::LangExtractError::parsing(format!(
                "Failed to parse output as JSON or YAML: {}",
                e
            ))
        })?;
        // Convert value-to-value. The previous to_string/from_str round trip
        // allocated an intermediate string and was subject to serde_json's
        // parser recursion limit.
        Ok(serde_json::to_value(yaml)?)
    }

    /// Output format this model produces. Defaults to [`FormatType::Json`].
    fn format_type(&self) -> FormatType {
        FormatType::Json
    }

    /// Identifier of the concrete model in use.
    fn model_id(&self) -> &str;

    /// Name of the provider backing this model.
    fn provider_name(&self) -> &str;

    /// Model-id fragments this provider handles. Defaults to an empty list.
    fn supported_models() -> Vec<&'static str>
    where
        Self: Sized,
    {
        vec![]
    }

    /// Returns `true` if any entry of [`supported_models`](Self::supported_models)
    /// occurs as a (case-sensitive) substring of `model_id`.
    ///
    /// Because substring matching is used, an empty-string entry matches every id.
    fn supports_model(model_id: &str) -> bool
    where
        Self: Sized,
    {
        Self::supported_models()
            .iter()
            .any(|&supported| model_id.contains(supported))
    }
}
/// Error reporting that a language model produced no scored outputs.
///
/// Displays as `No scored outputs available from the language model: {message}`.
#[derive(Debug, thiserror::Error)]
#[error("No scored outputs available from the language model: {message}")]
pub struct InferenceOutputError {
/// Free-form detail appended to the display message.
pub message: String,
}
impl InferenceOutputError {
pub fn new(message: String) -> Self {
Self { message }
}
}
/// Generation parameters passed to a language model.
///
/// Build one with `InferenceConfig::new()` and the `with_*` methods. Use
/// `to_hashmap` to flatten it for providers that take a key/value map.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceConfig {
/// Sampling temperature. Defaults to 0.5. `with_temperature` clamps it to
/// [0.0, 1.0], but direct field assignment is not checked.
pub temperature: f32,
/// Upper bound on generated tokens; `None` omits it from `to_hashmap`.
pub max_tokens: Option<usize>,
/// Number of candidates to request. Defaults to 1; `with_num_candidates`
/// enforces a minimum of 1.
pub num_candidates: usize,
/// Stop sequences; omitted from `to_hashmap` when empty.
pub stop_sequences: Vec<String>,
/// Additional provider-specific parameters. `to_hashmap` inserts these last,
/// so they override built-in keys with the same name.
pub extra_params: std::collections::HashMap<String, serde_json::Value>,
}
impl Default for InferenceConfig {
fn default() -> Self {
Self {
temperature: 0.5,
max_tokens: None,
num_candidates: 1,
stop_sequences: vec![],
extra_params: std::collections::HashMap::new(),
}
}
}
impl InferenceConfig {
    /// Creates a configuration with the default settings (see the `Default` impl).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the sampling temperature, clamped to `[0.0, 1.0]`.
    ///
    /// A `NaN` argument is ignored and the current temperature is kept. The
    /// guard is needed for two reasons:
    /// - `f32::clamp` returns `NaN` unchanged.
    /// - A `NaN` temperature is written as JSON `null`, both by `to_hashmap`
    ///   and by serde_json serialization, and `null` cannot be deserialized
    ///   back into an `f32`.
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        if !temperature.is_nan() {
            self.temperature = temperature.clamp(0.0, 1.0);
        }
        self
    }

    /// Sets the maximum number of tokens to generate.
    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets how many candidates to request; `0` is raised to `1`.
    pub fn with_num_candidates(mut self, num_candidates: usize) -> Self {
        self.num_candidates = num_candidates.max(1);
        self
    }

    /// Appends a stop sequence; previously added ones are kept.
    pub fn with_stop_sequence(mut self, stop_sequence: String) -> Self {
        self.stop_sequences.push(stop_sequence);
        self
    }

    /// Adds a provider-specific parameter, replacing any existing value under `key`.
    pub fn with_extra_param(mut self, key: String, value: serde_json::Value) -> Self {
        self.extra_params.insert(key, value);
        self
    }

    /// Flattens the configuration into a key/value map.
    ///
    /// Which keys appear:
    /// - `temperature` and `num_candidates` are always present.
    /// - `max_tokens` is present only when set.
    /// - `stop_sequences` is present only when non-empty.
    ///
    /// Entries from `extra_params` are inserted last, so an extra parameter
    /// named like a built-in key overrides it. `temperature` passes through
    /// serde_json's `f32` to `f64` widening, so e.g. `0.7` is emitted as
    /// `0.699999988079071`.
    pub fn to_hashmap(&self) -> std::collections::HashMap<String, serde_json::Value> {
        // At most four built-in keys plus every extra parameter.
        let mut map = std::collections::HashMap::with_capacity(4 + self.extra_params.len());
        map.insert("temperature".to_string(), serde_json::json!(self.temperature));
        if let Some(max_tokens) = self.max_tokens {
            map.insert("max_tokens".to_string(), serde_json::json!(max_tokens));
        }
        map.insert("num_candidates".to_string(), serde_json::json!(self.num_candidates));
        if !self.stop_sequences.is_empty() {
            map.insert("stop_sequences".to_string(), serde_json::json!(self.stop_sequences));
        }
        // Later inserts win, so extras deliberately override built-ins.
        map.extend(
            self.extra_params
                .iter()
                .map(|(key, value)| (key.clone(), value.clone())),
        );
        map
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_scored_output_creation() {
        let scored = ScoredOutput::new(String::from("Hello world"), Some(0.9));
        assert_eq!(scored.score, Some(0.9));
        assert!(scored.has_score());
        assert_eq!(scored.text(), "Hello world");

        let unscored = ScoredOutput::from_text(String::from("Hello world"));
        assert!(!unscored.has_score());
        assert_eq!(unscored.text(), "Hello world");
    }

    #[test]
    fn test_scored_output_display() {
        let rendered = ScoredOutput::new(String::from("Hello\nworld"), Some(0.85)).to_string();
        for expected in ["Score: 0.85", " Hello", " world"] {
            assert!(
                rendered.contains(expected),
                "{:?} not found in {:?}",
                expected,
                rendered
            );
        }

        let rendered = ScoredOutput::from_text(String::from("Test")).to_string();
        assert!(rendered.contains("Score: -"));
    }

    #[test]
    fn test_inference_config() {
        let config = InferenceConfig::new()
            .with_temperature(0.7)
            .with_max_tokens(100)
            .with_num_candidates(3)
            .with_stop_sequence(String::from("END"))
            .with_extra_param(String::from("custom_param"), json!("value"));

        assert_eq!(config.temperature, 0.7);
        assert_eq!(config.max_tokens, Some(100));
        assert_eq!(config.num_candidates, 3);
        assert_eq!(config.stop_sequences, ["END"]);

        let params = config.to_hashmap();
        let expected = [
            ("temperature", json!(0.7f32)),
            ("max_tokens", json!(100)),
            ("custom_param", json!("value")),
        ];
        for (key, value) in expected.iter() {
            assert_eq!(params.get(*key), Some(value), "mismatch for {}", key);
        }
    }

    #[test]
    fn test_temperature_clamping() {
        let cases = [(1.5_f32, 1.0_f32), (-0.5_f32, 0.0_f32)];
        for &(input, clamped) in cases.iter() {
            assert_eq!(InferenceConfig::new().with_temperature(input).temperature, clamped);
        }
    }

    #[test]
    fn test_serialization() {
        let output = ScoredOutput::new(String::from("test"), Some(0.5));
        let encoded = serde_json::to_string(&output).expect("ScoredOutput should serialize");
        let decoded: ScoredOutput =
            serde_json::from_str(&encoded).expect("ScoredOutput should deserialize");
        assert_eq!(decoded, output);

        let config = InferenceConfig::new().with_temperature(0.8);
        let encoded = serde_json::to_string(&config).expect("InferenceConfig should serialize");
        let decoded: InferenceConfig =
            serde_json::from_str(&encoded).expect("InferenceConfig should deserialize");
        assert_eq!(decoded.temperature, config.temperature);
    }
}