1use crate::{
2 CoreError, CoreResult, DecodeRequest, EncodeRequest, LanguageDescriptor, LanguageRegistry,
3 ModelCapability, ModelDescriptor, ModelRegistry, PipelineOptions, StrategyDescriptor,
4 StrategyRegistry,
5};
6
7#[derive(Debug, Clone, PartialEq, Eq)]
8pub struct ValidatedEncodeRequest {
9 pub language: LanguageDescriptor,
10 pub strategy: StrategyDescriptor,
11 pub model: Option<ModelDescriptor>,
12}
13
14#[derive(Debug, Clone, PartialEq, Eq)]
15pub struct ValidatedDecodeRequest {
16 pub language: LanguageDescriptor,
17 pub strategy: StrategyDescriptor,
18 pub model: Option<ModelDescriptor>,
19}
20
21pub fn validate_encode_request(
22 request: &EncodeRequest,
23 language_registry: &dyn LanguageRegistry,
24 strategy_registry: &dyn StrategyRegistry,
25 model_registry: &dyn ModelRegistry,
26) -> CoreResult<ValidatedEncodeRequest> {
27 let validated = validate_pipeline_options(
28 &request.options,
29 language_registry,
30 strategy_registry,
31 model_registry,
32 )?;
33
34 Ok(ValidatedEncodeRequest {
35 language: validated.language,
36 strategy: validated.strategy,
37 model: validated.model,
38 })
39}
40
41pub fn validate_decode_request(
42 request: &DecodeRequest,
43 language_registry: &dyn LanguageRegistry,
44 strategy_registry: &dyn StrategyRegistry,
45 model_registry: &dyn ModelRegistry,
46) -> CoreResult<ValidatedDecodeRequest> {
47 let validated = validate_pipeline_options(
48 &request.options,
49 language_registry,
50 strategy_registry,
51 model_registry,
52 )?;
53
54 Ok(ValidatedDecodeRequest {
55 language: validated.language,
56 strategy: validated.strategy,
57 model: validated.model,
58 })
59}
60
61#[derive(Debug, Clone, PartialEq, Eq)]
62struct ValidatedPipelineOptions {
63 language: LanguageDescriptor,
64 strategy: StrategyDescriptor,
65 model: Option<ModelDescriptor>,
66}
67
68fn validate_pipeline_options(
69 options: &PipelineOptions,
70 language_registry: &dyn LanguageRegistry,
71 strategy_registry: &dyn StrategyRegistry,
72 model_registry: &dyn ModelRegistry,
73) -> CoreResult<ValidatedPipelineOptions> {
74 let language = language_registry
75 .language(&options.language)
76 .cloned()
77 .ok_or_else(|| CoreError::UnsupportedLanguage(options.language.to_string()))?;
78
79 let strategy = strategy_registry
80 .strategy(&options.strategy)
81 .cloned()
82 .ok_or_else(|| CoreError::UnsupportedStrategy(options.strategy.to_string()))?;
83
84 let model = match &options.model_selection {
85 Some(selection) => {
86 let model = model_registry
87 .model(&selection.provider, &selection.model)
88 .cloned()
89 .ok_or_else(|| CoreError::UnsupportedModel {
90 provider: selection.provider.to_string(),
91 model: selection.model.to_string(),
92 })?;
93
94 ensure_model_supports_language(&model, &options.language)?;
95 ensure_required_capabilities(&model, &strategy)?;
96
97 Some(model)
98 }
99 None => {
100 if strategy.required_capabilities.is_empty() {
101 None
102 } else {
103 return Err(CoreError::StrategyRequiresModel(strategy.id.to_string()));
104 }
105 }
106 };
107
108 Ok(ValidatedPipelineOptions {
109 language,
110 strategy,
111 model,
112 })
113}
114
115fn ensure_model_supports_language(
116 model: &ModelDescriptor,
117 language: &crate::LanguageTag,
118) -> CoreResult<()> {
119 if model.supported_languages.iter().any(|tag| tag == language) {
120 Ok(())
121 } else {
122 Err(CoreError::ModelDoesNotSupportLanguage {
123 provider: model.provider.to_string(),
124 model: model.model.to_string(),
125 language: language.to_string(),
126 })
127 }
128}
129
130fn ensure_required_capabilities(
131 model: &ModelDescriptor,
132 strategy: &StrategyDescriptor,
133) -> CoreResult<()> {
134 for capability in &strategy.required_capabilities {
135 if !model.capabilities.contains(capability) {
136 return Err(CoreError::ModelMissingCapability {
137 provider: model.provider.to_string(),
138 model: model.model.to_string(),
139 capability: capability_name(*capability),
140 });
141 }
142 }
143
144 Ok(())
145}
146
147fn capability_name(capability: ModelCapability) -> &'static str {
148 capability.as_str()
149}
150
#[cfg(test)]
mod tests {
    use crate::{
        CoreError, DecodeRequest, EncodeRequest, LanguageDescriptor, LanguageRegistry,
        LanguageTag, ModelCapability, ModelDescriptor, ModelId, ModelRegistry, ModelSelection,
        PipelineOptions, ProviderId, StrategyDescriptor, StrategyId, StrategyRegistry,
        TextDirection, validate_decode_request, validate_encode_request,
    };

    /// Language registry backed by a fixed in-memory list.
    struct TestLanguageRegistry {
        values: Vec<LanguageDescriptor>,
    }

    impl LanguageRegistry for TestLanguageRegistry {
        fn all_languages(&self) -> &[LanguageDescriptor] {
            &self.values
        }
    }

    /// Strategy registry backed by a fixed in-memory list.
    struct TestStrategyRegistry {
        values: Vec<StrategyDescriptor>,
    }

    impl StrategyRegistry for TestStrategyRegistry {
        fn all_strategies(&self) -> &[StrategyDescriptor] {
            &self.values
        }
    }

    /// Model registry backed by a fixed in-memory list.
    struct TestModelRegistry {
        values: Vec<ModelDescriptor>,
    }

    impl ModelRegistry for TestModelRegistry {
        fn all_models(&self) -> &[ModelDescriptor] {
            &self.values
        }
    }

    // Fixture builders: every test assembles the same few descriptors, so building
    // them here keeps each test focused on the single option it varies.

    fn tag(value: &'static str) -> LanguageTag {
        LanguageTag::new(value).expect("valid language")
    }

    fn english() -> LanguageDescriptor {
        LanguageDescriptor {
            tag: tag("en"),
            display_name: "English".to_string(),
            direction: TextDirection::LeftToRight,
        }
    }

    fn persian() -> LanguageDescriptor {
        LanguageDescriptor {
            tag: tag("fa"),
            display_name: "Persian".to_string(),
            direction: TextDirection::RightToLeft,
        }
    }

    fn synonym_strategy() -> StrategyDescriptor {
        StrategyDescriptor {
            id: StrategyId::new("synonym").expect("valid strategy"),
            display_name: "Synonym".to_string(),
            required_capabilities: Vec::new(),
        }
    }

    fn probabilistic_strategy() -> StrategyDescriptor {
        StrategyDescriptor {
            id: StrategyId::new("probabilistic").expect("valid strategy"),
            display_name: "Probabilistic".to_string(),
            required_capabilities: vec![ModelCapability::TokenLogProbabilities],
        }
    }

    fn gpt_4o_mini_selection() -> ModelSelection {
        ModelSelection {
            provider: ProviderId::new("openai").expect("valid provider"),
            model: ModelId::new("gpt-4o-mini").expect("valid model"),
        }
    }

    fn gpt_4o_mini(
        supported_languages: Vec<LanguageTag>,
        capabilities: Vec<ModelCapability>,
    ) -> ModelDescriptor {
        ModelDescriptor {
            provider: ProviderId::new("openai").expect("valid provider"),
            model: ModelId::new("gpt-4o-mini").expect("valid model"),
            display_name: "GPT-4o Mini".to_string(),
            supported_languages,
            capabilities,
        }
    }

    fn options(
        language: &'static str,
        strategy: &'static str,
        model_selection: Option<ModelSelection>,
    ) -> PipelineOptions {
        PipelineOptions {
            language: tag(language),
            strategy: StrategyId::new(strategy).expect("valid strategy"),
            model_selection,
        }
    }

    fn encode_request(options: PipelineOptions) -> EncodeRequest {
        EncodeRequest {
            carrier_text: "hello".to_string(),
            payload: vec![1],
            options,
        }
    }

    fn decode_request(options: PipelineOptions) -> DecodeRequest {
        DecodeRequest {
            stego_text: "salam".to_string(),
            options,
        }
    }

    #[test]
    fn validate_encode_request_accepts_strategy_without_model_requirement() {
        let request = encode_request(options("en", "synonym", None));
        let languages = TestLanguageRegistry { values: vec![english()] };
        let strategies = TestStrategyRegistry { values: vec![synonym_strategy()] };
        let models = TestModelRegistry { values: Vec::new() };

        let validated = validate_encode_request(&request, &languages, &strategies, &models)
            .expect("request should validate");
        assert!(validated.model.is_none());
    }

    #[test]
    fn validate_encode_request_rejects_missing_model_when_strategy_requires_capability() {
        let request = encode_request(options("en", "probabilistic", None));
        let languages = TestLanguageRegistry { values: vec![english()] };
        let strategies = TestStrategyRegistry { values: vec![probabilistic_strategy()] };
        let models = TestModelRegistry { values: Vec::new() };

        let error = validate_encode_request(&request, &languages, &strategies, &models)
            .expect_err("request should fail");
        let message = error.to_string();
        assert!(message.contains("strategy requires a model selection"));
    }

    #[test]
    fn validate_encode_request_rejects_model_without_required_capability() {
        let request = encode_request(options(
            "fa",
            "probabilistic",
            Some(gpt_4o_mini_selection()),
        ));
        let languages = TestLanguageRegistry { values: vec![persian()] };
        let strategies = TestStrategyRegistry { values: vec![probabilistic_strategy()] };
        let models = TestModelRegistry {
            values: vec![gpt_4o_mini(
                vec![tag("fa")],
                vec![ModelCapability::StreamingGeneration],
            )],
        };

        let error = validate_encode_request(&request, &languages, &strategies, &models)
            .expect_err("request should fail");
        let message = error.to_string();
        assert!(message.contains("missing required capability"));
    }

    #[test]
    fn validate_encode_request_accepts_supported_model_and_capabilities() {
        let request = encode_request(options(
            "fa",
            "probabilistic",
            Some(gpt_4o_mini_selection()),
        ));
        let languages = TestLanguageRegistry { values: vec![persian()] };
        let strategies = TestStrategyRegistry { values: vec![probabilistic_strategy()] };
        let models = TestModelRegistry {
            values: vec![gpt_4o_mini(
                vec![tag("fa")],
                vec![
                    ModelCapability::TokenLogProbabilities,
                    ModelCapability::StreamingGeneration,
                ],
            )],
        };

        let validated = validate_encode_request(&request, &languages, &strategies, &models)
            .expect("request should validate");

        let model = validated.model.expect("validated model should exist");
        assert_eq!(model.display_name, "GPT-4o Mini");
        assert_eq!(validated.language.display_name, "Persian");
    }

    #[test]
    fn validate_encode_request_rejects_unknown_language() {
        let request = encode_request(options("de", "synonym", None));
        let languages = TestLanguageRegistry { values: vec![english()] };
        let strategies = TestStrategyRegistry { values: vec![synonym_strategy()] };
        let models = TestModelRegistry { values: Vec::new() };

        let error = validate_encode_request(&request, &languages, &strategies, &models)
            .expect_err("request should fail");
        assert!(matches!(error, CoreError::UnsupportedLanguage(_)));
    }

    #[test]
    fn validate_encode_request_rejects_unknown_strategy() {
        let request = encode_request(options("en", "acrostic", None));
        let languages = TestLanguageRegistry { values: vec![english()] };
        let strategies = TestStrategyRegistry { values: vec![synonym_strategy()] };
        let models = TestModelRegistry { values: Vec::new() };

        let error = validate_encode_request(&request, &languages, &strategies, &models)
            .expect_err("request should fail");
        assert!(matches!(error, CoreError::UnsupportedStrategy(_)));
    }

    #[test]
    fn validate_decode_request_rejects_unregistered_model() {
        let request = decode_request(options(
            "fa",
            "probabilistic",
            Some(gpt_4o_mini_selection()),
        ));
        let languages = TestLanguageRegistry { values: vec![persian()] };
        let strategies = TestStrategyRegistry { values: vec![probabilistic_strategy()] };
        let models = TestModelRegistry { values: Vec::new() };

        let error = validate_decode_request(&request, &languages, &strategies, &models)
            .expect_err("request should fail");
        assert!(matches!(error, CoreError::UnsupportedModel { .. }));
    }

    #[test]
    fn validate_decode_request_rejects_model_with_unsupported_language() {
        let request = decode_request(options(
            "fa",
            "probabilistic",
            Some(gpt_4o_mini_selection()),
        ));
        let languages = TestLanguageRegistry { values: vec![persian()] };
        let strategies = TestStrategyRegistry { values: vec![probabilistic_strategy()] };
        let models = TestModelRegistry {
            values: vec![gpt_4o_mini(
                vec![tag("en")],
                vec![ModelCapability::TokenLogProbabilities],
            )],
        };

        let error = validate_decode_request(&request, &languages, &strategies, &models)
            .expect_err("request should fail");
        assert!(error.to_string().contains("does not support language"));
    }

    #[test]
    fn validate_decode_request_accepts_supported_configuration() {
        let request = decode_request(options(
            "fa",
            "probabilistic",
            Some(gpt_4o_mini_selection()),
        ));
        let languages = TestLanguageRegistry { values: vec![persian()] };
        let strategies = TestStrategyRegistry { values: vec![probabilistic_strategy()] };
        let models = TestModelRegistry {
            values: vec![gpt_4o_mini(
                vec![tag("fa")],
                vec![ModelCapability::TokenLogProbabilities],
            )],
        };

        let validated = validate_decode_request(&request, &languages, &strategies, &models)
            .expect("request should validate");
        assert_eq!(validated.language.display_name, "Persian");
        assert_eq!(validated.strategy.display_name, "Probabilistic");
        assert!(validated.model.is_some());
    }
}