langextract_rust/
schema.rs1use crate::{data::ExampleData, exceptions::LangExtractResult};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6
/// JSON object key under which a list of extractions is expected.
///
/// NOTE(review): presumably the top-level key in model output that holds the
/// extraction list — confirm against the parser that consumes it.
pub const EXTRACTIONS_KEY: &str = "extractions";

10#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
12#[serde(rename_all = "snake_case")]
13pub enum ConstraintType {
14 None,
15}
16
17#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
19pub struct Constraint {
20 pub constraint_type: ConstraintType,
22}
23
24impl Default for Constraint {
25 fn default() -> Self {
26 Self {
27 constraint_type: ConstraintType::None,
28 }
29 }
30}
31
32impl Constraint {
33 pub fn none() -> Self {
35 Self::default()
36 }
37}
38
39pub trait BaseSchema: Send + Sync {
41 fn from_examples(
43 examples_data: &[ExampleData],
44 attribute_suffix: &str,
45 ) -> LangExtractResult<Box<dyn BaseSchema>>
46 where
47 Self: Sized;
48
49 fn to_provider_config(&self) -> HashMap<String, serde_json::Value>;
54
55 fn supports_strict_mode(&self) -> bool;
61
62 fn sync_with_provider_kwargs(&mut self, kwargs: &HashMap<String, serde_json::Value>) {
68 let _ = kwargs;
70 }
71
72 fn clone_box(&self) -> Box<dyn BaseSchema>;
74}
75
/// Schema that only selects an output format name (for example `"json"` or
/// `"yaml"`).
///
/// The format string is stored as given; no validation or normalization is
/// performed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FormatModeSchema {
    /// Format name, reported to providers under the `"format"` key.
    format: String,
}

impl Default for FormatModeSchema {
    /// Defaults to `"json"`, the same format `from_examples` produces.
    fn default() -> Self {
        Self::new("json")
    }
}

impl FormatModeSchema {
    /// Creates a schema for `format_mode`, copying the string as-is.
    pub fn new(format_mode: &str) -> Self {
        Self {
            format: format_mode.to_string(),
        }
    }

    /// Returns the current format name.
    pub fn format(&self) -> &str {
        &self.format
    }

    /// Replaces the current format name.
    pub fn set_format(&mut self, format: String) {
        self.format = format;
    }
}

105impl BaseSchema for FormatModeSchema {
106 fn from_examples(
107 _examples_data: &[ExampleData],
108 _attribute_suffix: &str,
109 ) -> LangExtractResult<Box<dyn BaseSchema>> {
110 Ok(Box::new(Self::new("json")))
113 }
114
115 fn to_provider_config(&self) -> HashMap<String, serde_json::Value> {
116 let mut config = HashMap::new();
117 config.insert("format".to_string(), serde_json::json!(self.format));
118 config
119 }
120
121 fn supports_strict_mode(&self) -> bool {
122 self.format == "json"
124 }
125
126 fn sync_with_provider_kwargs(&mut self, kwargs: &HashMap<String, serde_json::Value>) {
127 if let Some(format_value) = kwargs.get("format") {
128 if let Some(format_str) = format_value.as_str() {
129 self.format = format_str.to_string();
130 }
131 }
132 }
133
134 fn clone_box(&self) -> Box<dyn BaseSchema> {
135 Box::new(self.clone())
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::{ExampleData, Extraction};

    #[test]
    fn test_constraint_creation() {
        let constraint = Constraint::none();
        assert_eq!(constraint.constraint_type, ConstraintType::None);

        let default_constraint = Constraint::default();
        assert_eq!(default_constraint.constraint_type, ConstraintType::None);
    }

    #[test]
    fn test_format_mode_schema() {
        let mut schema = FormatModeSchema::new("json");
        assert_eq!(schema.format(), "json");
        assert!(schema.supports_strict_mode());

        schema.set_format("yaml".to_string());
        assert_eq!(schema.format(), "yaml");
        assert!(!schema.supports_strict_mode());
    }

    #[test]
    fn test_format_mode_schema_provider_config() {
        let schema = FormatModeSchema::new("json");
        let config = schema.to_provider_config();
        assert_eq!(config.get("format"), Some(&serde_json::json!("json")));
    }

    #[test]
    fn test_format_mode_schema_sync() {
        let mut schema = FormatModeSchema::new("json");

        let mut kwargs = HashMap::new();
        kwargs.insert("format".to_string(), serde_json::json!("yaml"));

        schema.sync_with_provider_kwargs(&kwargs);
        assert_eq!(schema.format(), "yaml");
        assert!(!schema.supports_strict_mode());
    }

    #[test]
    fn test_format_mode_schema_sync_ignores_missing_or_non_string_format() {
        let mut schema = FormatModeSchema::new("json");

        // No "format" key at all: the schema is unchanged.
        schema.sync_with_provider_kwargs(&HashMap::new());
        assert_eq!(schema.format(), "json");

        // A non-string "format" value is ignored rather than coerced.
        let mut kwargs = HashMap::new();
        kwargs.insert("format".to_string(), serde_json::json!(42));
        schema.sync_with_provider_kwargs(&kwargs);
        assert_eq!(schema.format(), "json");
        assert!(schema.supports_strict_mode());
    }

    #[test]
    fn test_format_mode_schema_from_examples() {
        let examples = vec![ExampleData::new(
            "Test text".to_string(),
            vec![Extraction::new("test".to_string(), "value".to_string())],
        )];

        let schema = FormatModeSchema::from_examples(&examples, "_attributes").unwrap();
        assert!(schema.supports_strict_mode());
    }

    #[test]
    fn test_format_mode_schema_from_empty_examples() {
        let schema = FormatModeSchema::from_examples(&[], "_attributes").unwrap();
        assert!(schema.supports_strict_mode());
        assert_eq!(
            schema.to_provider_config().get("format"),
            Some(&serde_json::json!("json"))
        );
    }

    #[test]
    fn test_format_mode_schema_clone_box() {
        let schema = FormatModeSchema::new("yaml");
        let boxed = schema.clone_box();
        assert!(!boxed.supports_strict_mode());
        assert_eq!(
            boxed.to_provider_config().get("format"),
            Some(&serde_json::json!("yaml"))
        );
        // Cloning must leave the source schema untouched.
        assert_eq!(schema.format(), "yaml");
    }

    #[test]
    fn test_constraint_serialization() {
        let constraint = Constraint::none();
        let json = serde_json::to_string(&constraint).unwrap();
        let deserialized: Constraint = serde_json::from_str(&json).unwrap();
        assert_eq!(constraint, deserialized);
    }

    #[test]
    fn test_constraint_serialized_shape() {
        // `rename_all = "snake_case"` on ConstraintType maps `None` to "none".
        let value = serde_json::to_value(Constraint::none()).unwrap();
        assert_eq!(value, serde_json::json!({ "constraint_type": "none" }));
    }
}