Skip to main content

linguasteg_core/
validation.rs

1use crate::{
2    CoreError, CoreResult, DecodeRequest, EncodeRequest, LanguageDescriptor, LanguageRegistry,
3    ModelCapability, ModelDescriptor, ModelRegistry, PipelineOptions, StrategyDescriptor,
4    StrategyRegistry,
5};
6
7#[derive(Debug, Clone, PartialEq, Eq)]
8pub struct ValidatedEncodeRequest {
9    pub language: LanguageDescriptor,
10    pub strategy: StrategyDescriptor,
11    pub model: Option<ModelDescriptor>,
12}
13
14#[derive(Debug, Clone, PartialEq, Eq)]
15pub struct ValidatedDecodeRequest {
16    pub language: LanguageDescriptor,
17    pub strategy: StrategyDescriptor,
18    pub model: Option<ModelDescriptor>,
19}
20
21pub fn validate_encode_request(
22    request: &EncodeRequest,
23    language_registry: &dyn LanguageRegistry,
24    strategy_registry: &dyn StrategyRegistry,
25    model_registry: &dyn ModelRegistry,
26) -> CoreResult<ValidatedEncodeRequest> {
27    let validated = validate_pipeline_options(
28        &request.options,
29        language_registry,
30        strategy_registry,
31        model_registry,
32    )?;
33
34    Ok(ValidatedEncodeRequest {
35        language: validated.language,
36        strategy: validated.strategy,
37        model: validated.model,
38    })
39}
40
41pub fn validate_decode_request(
42    request: &DecodeRequest,
43    language_registry: &dyn LanguageRegistry,
44    strategy_registry: &dyn StrategyRegistry,
45    model_registry: &dyn ModelRegistry,
46) -> CoreResult<ValidatedDecodeRequest> {
47    let validated = validate_pipeline_options(
48        &request.options,
49        language_registry,
50        strategy_registry,
51        model_registry,
52    )?;
53
54    Ok(ValidatedDecodeRequest {
55        language: validated.language,
56        strategy: validated.strategy,
57        model: validated.model,
58    })
59}
60
61#[derive(Debug, Clone, PartialEq, Eq)]
62struct ValidatedPipelineOptions {
63    language: LanguageDescriptor,
64    strategy: StrategyDescriptor,
65    model: Option<ModelDescriptor>,
66}
67
68fn validate_pipeline_options(
69    options: &PipelineOptions,
70    language_registry: &dyn LanguageRegistry,
71    strategy_registry: &dyn StrategyRegistry,
72    model_registry: &dyn ModelRegistry,
73) -> CoreResult<ValidatedPipelineOptions> {
74    let language = language_registry
75        .language(&options.language)
76        .cloned()
77        .ok_or_else(|| CoreError::UnsupportedLanguage(options.language.to_string()))?;
78
79    let strategy = strategy_registry
80        .strategy(&options.strategy)
81        .cloned()
82        .ok_or_else(|| CoreError::UnsupportedStrategy(options.strategy.to_string()))?;
83
84    let model = match &options.model_selection {
85        Some(selection) => {
86            let model = model_registry
87                .model(&selection.provider, &selection.model)
88                .cloned()
89                .ok_or_else(|| CoreError::UnsupportedModel {
90                    provider: selection.provider.to_string(),
91                    model: selection.model.to_string(),
92                })?;
93
94            ensure_model_supports_language(&model, &options.language)?;
95            ensure_required_capabilities(&model, &strategy)?;
96
97            Some(model)
98        }
99        None => {
100            if strategy.required_capabilities.is_empty() {
101                None
102            } else {
103                return Err(CoreError::StrategyRequiresModel(strategy.id.to_string()));
104            }
105        }
106    };
107
108    Ok(ValidatedPipelineOptions {
109        language,
110        strategy,
111        model,
112    })
113}
114
115fn ensure_model_supports_language(
116    model: &ModelDescriptor,
117    language: &crate::LanguageTag,
118) -> CoreResult<()> {
119    if model.supported_languages.iter().any(|tag| tag == language) {
120        Ok(())
121    } else {
122        Err(CoreError::ModelDoesNotSupportLanguage {
123            provider: model.provider.to_string(),
124            model: model.model.to_string(),
125            language: language.to_string(),
126        })
127    }
128}
129
130fn ensure_required_capabilities(
131    model: &ModelDescriptor,
132    strategy: &StrategyDescriptor,
133) -> CoreResult<()> {
134    for capability in &strategy.required_capabilities {
135        if !model.capabilities.contains(capability) {
136            return Err(CoreError::ModelMissingCapability {
137                provider: model.provider.to_string(),
138                model: model.model.to_string(),
139                capability: capability_name(*capability),
140            });
141        }
142    }
143
144    Ok(())
145}
146
147fn capability_name(capability: ModelCapability) -> &'static str {
148    capability.as_str()
149}
150
#[cfg(test)]
mod tests {
    //! Unit tests for request validation.
    //!
    //! Fixtures are built through small helpers so each test states only the
    //! detail it exercises (which language the model supports, which
    //! capabilities it has, whether a model was selected).

    use crate::{
        CoreError, DecodeRequest, EncodeRequest, LanguageDescriptor, LanguageRegistry, LanguageTag,
        ModelCapability, ModelDescriptor, ModelId, ModelRegistry, ModelSelection, PipelineOptions,
        ProviderId, StrategyDescriptor, StrategyId, StrategyRegistry, TextDirection,
        validate_decode_request, validate_encode_request,
    };

    /// Language registry backed by a plain vector.
    struct TestLanguageRegistry {
        values: Vec<LanguageDescriptor>,
    }

    impl LanguageRegistry for TestLanguageRegistry {
        fn all_languages(&self) -> &[LanguageDescriptor] {
            &self.values
        }
    }

    /// Strategy registry backed by a plain vector.
    struct TestStrategyRegistry {
        values: Vec<StrategyDescriptor>,
    }

    impl StrategyRegistry for TestStrategyRegistry {
        fn all_strategies(&self) -> &[StrategyDescriptor] {
            &self.values
        }
    }

    /// Model registry backed by a plain vector.
    struct TestModelRegistry {
        values: Vec<ModelDescriptor>,
    }

    impl ModelRegistry for TestModelRegistry {
        fn all_models(&self) -> &[ModelDescriptor] {
            &self.values
        }
    }

    /// Registry containing only English (left-to-right).
    fn english_languages() -> TestLanguageRegistry {
        TestLanguageRegistry {
            values: vec![LanguageDescriptor {
                tag: LanguageTag::new("en").expect("valid language"),
                display_name: "English".to_string(),
                direction: TextDirection::LeftToRight,
            }],
        }
    }

    /// Registry containing only Persian (right-to-left).
    fn persian_languages() -> TestLanguageRegistry {
        TestLanguageRegistry {
            values: vec![LanguageDescriptor {
                tag: LanguageTag::new("fa").expect("valid language"),
                display_name: "Persian".to_string(),
                direction: TextDirection::RightToLeft,
            }],
        }
    }

    /// Registry with the "synonym" strategy, which requires no capabilities.
    fn synonym_strategies() -> TestStrategyRegistry {
        TestStrategyRegistry {
            values: vec![StrategyDescriptor {
                id: StrategyId::new("synonym").expect("valid strategy"),
                display_name: "Synonym".to_string(),
                required_capabilities: Vec::new(),
            }],
        }
    }

    /// Registry with the "probabilistic" strategy, which requires token
    /// log-probabilities from the model.
    fn probabilistic_strategies() -> TestStrategyRegistry {
        TestStrategyRegistry {
            values: vec![StrategyDescriptor {
                id: StrategyId::new("probabilistic").expect("valid strategy"),
                display_name: "Probabilistic".to_string(),
                required_capabilities: vec![ModelCapability::TokenLogProbabilities],
            }],
        }
    }

    /// Registry with no models at all.
    fn no_models() -> TestModelRegistry {
        TestModelRegistry { values: Vec::new() }
    }

    /// Registry with a single "openai"/"gpt-4o-mini" model supporting one
    /// language and the given capabilities.
    fn gpt_4o_mini_models(
        supported_language: &'static str,
        capabilities: Vec<ModelCapability>,
    ) -> TestModelRegistry {
        TestModelRegistry {
            values: vec![ModelDescriptor {
                provider: ProviderId::new("openai").expect("valid provider"),
                model: ModelId::new("gpt-4o-mini").expect("valid model"),
                display_name: "GPT-4o Mini".to_string(),
                supported_languages: vec![
                    LanguageTag::new(supported_language).expect("valid language"),
                ],
                capabilities,
            }],
        }
    }

    /// Selection pointing at the model built by `gpt_4o_mini_models`.
    fn gpt_4o_mini_selection() -> ModelSelection {
        ModelSelection {
            provider: ProviderId::new("openai").expect("valid provider"),
            model: ModelId::new("gpt-4o-mini").expect("valid model"),
        }
    }

    /// Pipeline options for the given language tag and strategy id.
    fn pipeline_options(
        language: &'static str,
        strategy: &'static str,
        model_selection: Option<ModelSelection>,
    ) -> PipelineOptions {
        PipelineOptions {
            language: LanguageTag::new(language).expect("valid language"),
            strategy: StrategyId::new(strategy).expect("valid strategy"),
            model_selection,
        }
    }

    #[test]
    fn validate_encode_request_accepts_strategy_without_model_requirement() {
        let request = EncodeRequest {
            carrier_text: "hello".to_string(),
            payload: vec![1, 2, 3],
            options: pipeline_options("en", "synonym", None),
        };
        let languages = english_languages();
        let strategies = synonym_strategies();
        let models = no_models();

        let validated = validate_encode_request(&request, &languages, &strategies, &models)
            .expect("request should validate");
        assert!(validated.model.is_none());
    }

    #[test]
    fn validate_encode_request_rejects_missing_model_when_strategy_requires_capability() {
        let request = EncodeRequest {
            carrier_text: "hello".to_string(),
            payload: vec![1],
            options: pipeline_options("en", "probabilistic", None),
        };
        let languages = english_languages();
        let strategies = probabilistic_strategies();
        let models = no_models();

        let error = validate_encode_request(&request, &languages, &strategies, &models)
            .expect_err("request should fail");
        let message = error.to_string();
        assert!(message.contains("strategy requires a model selection"));
    }

    #[test]
    fn validate_encode_request_rejects_model_without_required_capability() {
        let request = EncodeRequest {
            carrier_text: "hello".to_string(),
            payload: vec![1],
            options: pipeline_options("fa", "probabilistic", Some(gpt_4o_mini_selection())),
        };
        let languages = persian_languages();
        let strategies = probabilistic_strategies();
        // The model streams but does not expose token log-probabilities.
        let models = gpt_4o_mini_models("fa", vec![ModelCapability::StreamingGeneration]);

        let error = validate_encode_request(&request, &languages, &strategies, &models)
            .expect_err("request should fail");
        let message = error.to_string();
        assert!(message.contains("missing required capability"));
    }

    #[test]
    fn validate_encode_request_accepts_supported_model_and_capabilities() {
        let request = EncodeRequest {
            carrier_text: "hello".to_string(),
            payload: vec![1],
            options: pipeline_options("fa", "probabilistic", Some(gpt_4o_mini_selection())),
        };
        let languages = persian_languages();
        let strategies = probabilistic_strategies();
        let models = gpt_4o_mini_models(
            "fa",
            vec![
                ModelCapability::TokenLogProbabilities,
                ModelCapability::StreamingGeneration,
            ],
        );

        let validated = validate_encode_request(&request, &languages, &strategies, &models)
            .expect("request should validate");

        let model = validated.model.expect("validated model should exist");
        assert_eq!(model.display_name, "GPT-4o Mini");
        assert_eq!(validated.language.display_name, "Persian");
    }

    #[test]
    fn validate_encode_request_rejects_unregistered_language() {
        let request = EncodeRequest {
            carrier_text: "hello".to_string(),
            payload: vec![1],
            options: pipeline_options("en", "synonym", None),
        };
        let languages = TestLanguageRegistry { values: Vec::new() };
        let strategies = synonym_strategies();
        let models = no_models();

        let error = validate_encode_request(&request, &languages, &strategies, &models)
            .expect_err("request should fail");
        assert!(matches!(error, CoreError::UnsupportedLanguage(_)));
    }

    #[test]
    fn validate_encode_request_rejects_unregistered_strategy() {
        let request = EncodeRequest {
            carrier_text: "hello".to_string(),
            payload: vec![1],
            options: pipeline_options("en", "synonym", None),
        };
        let languages = english_languages();
        let strategies = TestStrategyRegistry { values: Vec::new() };
        let models = no_models();

        let error = validate_encode_request(&request, &languages, &strategies, &models)
            .expect_err("request should fail");
        assert!(matches!(error, CoreError::UnsupportedStrategy(_)));
    }

    #[test]
    fn validate_encode_request_rejects_unregistered_model() {
        let request = EncodeRequest {
            carrier_text: "hello".to_string(),
            payload: vec![1],
            options: pipeline_options("fa", "probabilistic", Some(gpt_4o_mini_selection())),
        };
        let languages = persian_languages();
        let strategies = probabilistic_strategies();
        let models = no_models();

        let error = validate_encode_request(&request, &languages, &strategies, &models)
            .expect_err("request should fail");
        assert!(matches!(error, CoreError::UnsupportedModel { .. }));
    }

    #[test]
    fn validate_decode_request_rejects_model_with_unsupported_language() {
        let request = DecodeRequest {
            stego_text: "salam".to_string(),
            options: pipeline_options("fa", "probabilistic", Some(gpt_4o_mini_selection())),
        };
        let languages = persian_languages();
        let strategies = probabilistic_strategies();
        // The model only supports English while the request asks for Persian.
        let models = gpt_4o_mini_models("en", vec![ModelCapability::TokenLogProbabilities]);

        let error = validate_decode_request(&request, &languages, &strategies, &models)
            .expect_err("request should fail");
        assert!(error.to_string().contains("does not support language"));
    }

    #[test]
    fn validate_decode_request_accepts_supported_configuration() {
        let request = DecodeRequest {
            stego_text: "salam".to_string(),
            options: pipeline_options("fa", "probabilistic", Some(gpt_4o_mini_selection())),
        };
        let languages = persian_languages();
        let strategies = probabilistic_strategies();
        let models = gpt_4o_mini_models("fa", vec![ModelCapability::TokenLogProbabilities]);

        let validated = validate_decode_request(&request, &languages, &strategies, &models)
            .expect("request should validate");
        assert_eq!(validated.language.display_name, "Persian");
        assert_eq!(validated.strategy.display_name, "Probabilistic");
        assert!(validated.model.is_some());
    }
}