1use serde::{Deserialize, Serialize};
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct Slot {
10 pub name: String,
12
13 pub prompt: String,
15
16 pub kind: SlotKind,
18
19 pub constraints: Option<SlotConstraints>,
21
22 pub required: bool,
24
25 pub default: Option<String>,
27
28 pub temperature: Option<f32>,
30}
31
32#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
34#[serde(rename_all = "snake_case")]
35pub enum SlotKind {
36 #[default]
38 Raw,
39
40 Function,
42
43 Class,
45
46 Html,
48
49 Css,
51
52 JavaScript,
54
55 Component,
57
58 Custom(String),
60}
61
62#[derive(Debug, Clone, Default, Serialize, Deserialize)]
64pub struct SlotConstraints {
65 pub max_lines: Option<usize>,
67
68 pub max_chars: Option<usize>,
70
71 pub required_imports: Vec<String>,
73
74 pub forbidden_patterns: Vec<String>,
76
77 pub language: Option<String>,
79}
80
81impl Slot {
82 pub fn new(name: impl Into<String>, prompt: impl Into<String>) -> Self {
98 Self {
99 name: name.into(),
100 prompt: prompt.into(),
101 kind: SlotKind::default(),
102 constraints: None,
103 required: true,
104 default: None,
105 temperature: None,
106 }
107 }
108
109 pub fn with_temperature(mut self, temp: f32) -> Self {
111 self.temperature = Some(temp.clamp(0.0, 2.0));
112 self
113 }
114
115 pub fn with_kind(mut self, kind: SlotKind) -> Self {
117 self.kind = kind;
118 self
119 }
120
121 pub fn with_constraints(mut self, constraints: SlotConstraints) -> Self {
123 self.constraints = Some(constraints);
124 self
125 }
126
127 pub fn optional(mut self, default: impl Into<String>) -> Self {
129 self.required = false;
130 self.default = Some(default.into());
131 self
132 }
133
134 pub fn validate(&self, code: &str) -> Result<(), Vec<String>> {
136 let mut errors = Vec::new();
137
138 if let Some(ref constraints) = self.constraints {
139 if let Some(max) = constraints.max_lines {
141 let lines = code.lines().count();
142 if lines > max {
143 errors.push(format!("Code exceeds max lines: {} > {}", lines, max));
144 }
145 }
146
147 if let Some(max) = constraints.max_chars {
149 if code.len() > max {
150 errors.push(format!("Code exceeds max chars: {} > {}", code.len(), max));
151 }
152 }
153
154 for pattern in &constraints.forbidden_patterns {
156 if let Ok(re) = regex::Regex::new(pattern) {
157 if re.is_match(code) {
158 errors.push(format!("Code contains forbidden pattern: {}", pattern));
159 }
160 }
161 }
162 }
163
164 if errors.is_empty() {
165 Ok(())
166 } else {
167 Err(errors)
168 }
169 }
170}
171
172impl SlotConstraints {
173 pub fn new() -> Self {
175 Self::default()
176 }
177
178 pub fn max_lines(mut self, lines: usize) -> Self {
180 self.max_lines = Some(lines);
181 self
182 }
183
184 pub fn max_chars(mut self, chars: usize) -> Self {
186 self.max_chars = Some(chars);
187 self
188 }
189
190 pub fn language(mut self, lang: impl Into<String>) -> Self {
192 self.language = Some(lang.into());
193 self
194 }
195
196 pub fn require_import(mut self, import: impl Into<String>) -> Self {
198 self.required_imports.push(import.into());
199 self
200 }
201
202 pub fn forbid_pattern(mut self, pattern: impl Into<String>) -> Self {
204 self.forbidden_patterns.push(pattern.into());
205 self
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_slot_creation() {
        let slot = Slot::new("test", "Generate a test");
        assert_eq!(slot.name, "test");
        assert_eq!(slot.prompt, "Generate a test");
        assert!(slot.required);
        assert_eq!(slot.kind, SlotKind::Raw);
        assert!(slot.constraints.is_none());
        assert!(slot.default.is_none());
        assert!(slot.temperature.is_none());
    }

    #[test]
    fn test_slot_validation() {
        let slot = Slot::new("test", "")
            .with_constraints(SlotConstraints::new().max_lines(5));

        assert!(slot.validate("line1\nline2\nline3").is_ok());
        assert!(slot.validate("1\n2\n3\n4\n5\n6").is_err());
    }

    #[test]
    fn test_optional_slot() {
        let slot = Slot::new("test", "").optional("fallback");
        assert!(!slot.required);
        assert_eq!(slot.default.as_deref(), Some("fallback"));
    }

    #[test]
    fn test_temperature_is_clamped() {
        assert_eq!(Slot::new("t", "").with_temperature(5.0).temperature, Some(2.0));
        assert_eq!(Slot::new("t", "").with_temperature(-1.0).temperature, Some(0.0));
        assert_eq!(Slot::new("t", "").with_temperature(0.7).temperature, Some(0.7));
    }

    #[test]
    fn test_max_chars() {
        let slot = Slot::new("t", "")
            .with_constraints(SlotConstraints::new().max_chars(3));
        assert!(slot.validate("abc").is_ok());
        assert!(slot.validate("abcd").is_err());
    }

    #[test]
    fn test_forbidden_pattern() {
        let slot = Slot::new("t", "")
            .with_constraints(SlotConstraints::new().forbid_pattern(r"\bunsafe\b"));
        assert!(slot.validate("fn f() {}").is_ok());
        let errors = slot.validate("unsafe { f() }").unwrap_err();
        assert_eq!(errors.len(), 1);
    }

    #[test]
    fn test_no_constraints_accepts_anything() {
        let slot = Slot::new("t", "");
        assert!(slot.validate(&"x\n".repeat(1000)).is_ok());
    }

    #[test]
    fn test_all_violations_are_reported() {
        let slot = Slot::new("t", "")
            .with_constraints(SlotConstraints::new().max_lines(1).max_chars(2));
        let errors = slot.validate("a\nbc").unwrap_err();
        assert_eq!(errors.len(), 2);
    }
}
229}