langextract_rust/factory.rs

//! Factory for creating language model instances.

use crate::{
    data::ExampleData,
    exceptions::{LangExtractError, LangExtractResult},
    inference::BaseLanguageModel,
    providers::{create_provider, ProviderConfig},
    ExtractConfig,
};

#[cfg(test)]
use crate::providers::ProviderType;

/// Create a language model instance based on the given configuration.
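///
/// A minimal usage sketch, assuming an explicit Ollama provider configuration and a running
/// Ollama instance; the import paths below are illustrative and may differ from the crate's
/// actual public re-exports:
///
/// ```ignore
/// use langextract_rust::{factory::create_model, providers::ProviderConfig, ExtractConfig};
///
/// async fn build_model() -> Result<(), Box<dyn std::error::Error>> {
///     // Describe the provider explicitly; there is no auto-detection from model names.
///     let provider_config =
///         ProviderConfig::ollama("mistral", Some("http://localhost:11434".to_string()));
///
///     let mut config = ExtractConfig {
///         model_id: "mistral".to_string(),
///         ..Default::default()
///     };
///     config.language_model_params.insert(
///         "provider_config".to_string(),
///         serde_json::to_value(&provider_config)?,
///     );
///
///     // No examples are supplied here, so schema constraints are not applied.
///     let _model = create_model(&config, None).await?;
///     Ok(())
/// }
/// ```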
pub async fn create_model(
    config: &ExtractConfig,
    examples: Option<&[ExampleData]>,
) -> LangExtractResult<Box<dyn BaseLanguageModel>> {
    // Determine provider type and configuration from the ExtractConfig
    let provider_config = create_provider_config(config)?;

    // Create the provider
    let mut provider = create_provider(provider_config)?;

    // Apply schema constraints when they are enabled and examples are provided
    if let Some(example_data) = examples {
        if config.use_schema_constraints && !example_data.is_empty() {
            if let Some(schema_class) = provider.get_schema_class() {
                // For now, apply the schema class exposed by the provider
                provider.apply_schema(Some(schema_class));
            }
        }
    }

    // Set fence output preference
    provider.set_fence_output(config.fence_output);

    Ok(Box::new(provider))
}
40
/// Create a provider configuration from an `ExtractConfig`.
fn create_provider_config(config: &ExtractConfig) -> LangExtractResult<ProviderConfig> {
    // Check if a provider configuration is already specified in language_model_params
    if let Some(provider_config_value) = config.language_model_params.get("provider_config") {
        if let Ok(provider_config) = serde_json::from_value::<ProviderConfig>(provider_config_value.clone()) {
            return Ok(provider_config);
        }
    }

    // Provider configuration is required - no auto-detection
    Err(LangExtractError::configuration(
        "Provider configuration is required. Please specify a provider explicitly:\n\
         1. Via CLI: --provider <openai|ollama|custom>\n\
         2. Via config: Set language_model_params['provider_config']\n\
         3. Via ProviderConfig in code\n\n\
         Auto-detection based on model names has been removed in favor of explicit configuration."
    ))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_explicit_provider_required() {
        let config = ExtractConfig {
            model_id: "mistral".to_string(),
            api_key: None,
            ..Default::default()
        };

        let result = create_provider_config(&config);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Provider configuration is required"));
    }

    #[test]
    fn test_create_provider_config_with_explicit_config() {
        let provider_config = ProviderConfig::ollama("mistral", Some("http://localhost:11434".to_string()));

        let mut config = ExtractConfig {
            model_id: "mistral".to_string(),
            api_key: None,
            ..Default::default()
        };

        // Set explicit provider config
        config.language_model_params.insert(
            "provider_config".to_string(),
            serde_json::to_value(&provider_config).unwrap(),
        );

        let result_config = create_provider_config(&config).unwrap();
        assert_eq!(result_config.provider_type, ProviderType::Ollama);
        assert_eq!(result_config.model, "mistral");
        assert_eq!(result_config.base_url, "http://localhost:11434");
    }
}