use crate::{data::ExampleData, exceptions::LangExtractResult};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Key string `"extractions"`.
///
/// NOTE(review): not referenced anywhere in this file; presumably the key under
/// which extraction results appear in model output — confirm against callers.
pub const EXTRACTIONS_KEY: &str = "extractions";
/// Kind of constraint carried by a [`Constraint`].
///
/// `None` is currently the only variant and is what [`Constraint::default`]
/// produces. Variants are serialized in snake_case, so `None` is written as `"none"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ConstraintType {
    /// No constraint; the default kind for [`Constraint`].
    None,
}
/// Serializable constraint description, identified by its [`ConstraintType`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Constraint {
    /// Which kind of constraint this is.
    pub constraint_type: ConstraintType,
}
impl Default for Constraint {
fn default() -> Self {
Self {
constraint_type: ConstraintType::None,
}
}
}
impl Constraint {
    /// Returns a constraint of type [`ConstraintType::None`].
    ///
    /// Equivalent to [`Constraint::default`]; provided as a more descriptive name.
    pub fn none() -> Self {
        <Self as Default>::default()
    }
}
/// Common interface for output schemas that are turned into model-provider configuration.
///
/// The `Send + Sync` supertraits allow boxed schemas to be shared across threads.
pub trait BaseSchema: Send + Sync {
    /// Builds a boxed schema from example data.
    ///
    /// How `examples_data` and `attribute_suffix` are used is up to the implementation.
    /// NOTE(review): the suffix is presumably appended to attribute keys (the tests pass
    /// `"_attributes"`) — confirm.
    ///
    /// The `Self: Sized` bound excludes this receiver-less constructor from the vtable,
    /// which keeps `BaseSchema` usable as `dyn BaseSchema`.
    fn from_examples(
        examples_data: &[ExampleData],
        attribute_suffix: &str,
    ) -> LangExtractResult<Box<dyn BaseSchema>>
    where
        Self: Sized;
    /// Returns the configuration entries this schema contributes to a provider.
    fn to_provider_config(&self) -> HashMap<String, serde_json::Value>;
    /// Reports whether this schema supports the provider's strict mode.
    fn supports_strict_mode(&self) -> bool;
    /// Lets the schema adopt settings from provider keyword arguments.
    ///
    /// The default implementation is a no-op.
    fn sync_with_provider_kwargs(&mut self, kwargs: &HashMap<String, serde_json::Value>) {
        // Discard the argument explicitly so the default impl compiles without an
        // unused-variable warning while keeping the parameter name.
        let _ = kwargs;
    }
    /// Clones this schema into a new box.
    ///
    /// `Clone` is not object-safe, so trait objects need this hook to be duplicated.
    fn clone_box(&self) -> Box<dyn BaseSchema>;
}
/// Schema that fixes only the provider's output format string (e.g. `"json"` or `"yaml"`),
/// without describing the structure of the output.
#[derive(Debug, Clone)]
pub struct FormatModeSchema {
    /// Output format name that is forwarded to the provider.
    format: String,
}

impl FormatModeSchema {
    /// Creates a schema that requests `format_mode` as the output format.
    pub fn new(format_mode: &str) -> Self {
        let format = String::from(format_mode);
        FormatModeSchema { format }
    }

    /// Returns the currently configured output format.
    pub fn format(&self) -> &str {
        self.format.as_str()
    }

    /// Replaces the configured output format with `format`.
    pub fn set_format(&mut self, format: String) {
        self.format = format;
    }
}
impl BaseSchema for FormatModeSchema {
fn from_examples(
_examples_data: &[ExampleData],
_attribute_suffix: &str,
) -> LangExtractResult<Box<dyn BaseSchema>> {
Ok(Box::new(Self::new("json")))
}
fn to_provider_config(&self) -> HashMap<String, serde_json::Value> {
let mut config = HashMap::new();
config.insert("format".to_string(), serde_json::json!(self.format));
config
}
fn supports_strict_mode(&self) -> bool {
self.format == "json"
}
fn sync_with_provider_kwargs(&mut self, kwargs: &HashMap<String, serde_json::Value>) {
if let Some(format_value) = kwargs.get("format") {
if let Some(format_str) = format_value.as_str() {
self.format = format_str.to_string();
}
}
}
fn clone_box(&self) -> Box<dyn BaseSchema> {
Box::new(self.clone())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::{ExampleData, Extraction};

    #[test]
    fn test_constraint_creation() {
        let constraint = Constraint::none();
        assert_eq!(constraint.constraint_type, ConstraintType::None);
        let default_constraint = Constraint::default();
        assert_eq!(default_constraint.constraint_type, ConstraintType::None);
    }

    #[test]
    fn test_format_mode_schema() {
        let mut schema = FormatModeSchema::new("json");
        assert_eq!(schema.format(), "json");
        assert!(schema.supports_strict_mode());
        schema.set_format("yaml".to_string());
        assert_eq!(schema.format(), "yaml");
        assert!(!schema.supports_strict_mode());
    }

    #[test]
    fn test_format_mode_schema_provider_config() {
        let schema = FormatModeSchema::new("json");
        let config = schema.to_provider_config();
        // Only the `format` entry is emitted.
        assert_eq!(config.len(), 1);
        assert_eq!(config.get("format"), Some(&serde_json::json!("json")));
    }

    #[test]
    fn test_format_mode_schema_sync() {
        let mut schema = FormatModeSchema::new("json");
        let mut kwargs = HashMap::new();
        kwargs.insert("format".to_string(), serde_json::json!("yaml"));
        schema.sync_with_provider_kwargs(&kwargs);
        assert_eq!(schema.format(), "yaml");
        assert!(!schema.supports_strict_mode());
    }

    #[test]
    fn test_format_mode_schema_sync_ignores_missing_or_non_string_format() {
        let mut schema = FormatModeSchema::new("json");

        // No `format` key: the schema is left untouched.
        schema.sync_with_provider_kwargs(&HashMap::new());
        assert_eq!(schema.format(), "json");

        // A non-string `format` value is ignored rather than coerced.
        let mut kwargs = HashMap::new();
        kwargs.insert("format".to_string(), serde_json::json!(42));
        schema.sync_with_provider_kwargs(&kwargs);
        assert_eq!(schema.format(), "json");
        assert!(schema.supports_strict_mode());
    }

    #[test]
    fn test_format_mode_schema_from_examples() {
        let examples = vec![ExampleData::new(
            "Test text".to_string(),
            vec![Extraction::new("test".to_string(), "value".to_string())],
        )];
        let schema = FormatModeSchema::from_examples(&examples, "_attributes").unwrap();
        assert!(schema.supports_strict_mode());
        assert_eq!(
            schema.to_provider_config().get("format"),
            Some(&serde_json::json!("json"))
        );
    }

    #[test]
    fn test_format_mode_schema_clone_box() {
        let original = FormatModeSchema::new("yaml");
        let boxed = original.clone_box();
        // The boxed copy must carry the same format as the source schema.
        assert!(!boxed.supports_strict_mode());
        assert_eq!(
            boxed.to_provider_config().get("format"),
            Some(&serde_json::json!("yaml"))
        );
    }

    #[test]
    fn test_constraint_serialization() {
        let constraint = Constraint::none();
        let json = serde_json::to_string(&constraint).unwrap();
        // `rename_all = "snake_case"` puts the variant on the wire as "none".
        assert_eq!(json, r#"{"constraint_type":"none"}"#);
        let deserialized: Constraint = serde_json::from_str(&json).unwrap();
        assert_eq!(constraint, deserialized);
    }
}