langextract_rust/
factory.rs1use crate::{
4 data::ExampleData,
5 exceptions::{LangExtractError, LangExtractResult},
6 inference::BaseLanguageModel,
7 providers::{create_provider, ProviderConfig},
8 ExtractConfig,
9};
10
11#[cfg(test)]
12use crate::providers::ProviderType;
13
14pub async fn create_model(
16 config: &ExtractConfig,
17 examples: Option<&[ExampleData]>,
18) -> LangExtractResult<Box<dyn BaseLanguageModel>> {
19 let provider_config = create_provider_config(config)?;
21
22 let mut provider = create_provider(provider_config)?;
24
25 if let Some(example_data) = examples {
27 if config.use_schema_constraints && !example_data.is_empty() {
28 if let Some(schema_class) = provider.get_schema_class() {
29 provider.apply_schema(Some(schema_class));
31 }
32 }
33 }
34
35 provider.set_fence_output(config.fence_output);
37
38 Ok(Box::new(provider))
39}
40
41fn create_provider_config(config: &ExtractConfig) -> LangExtractResult<ProviderConfig> {
43 if let Some(provider_config_value) = config.language_model_params.get("provider_config") {
45 if let Ok(provider_config) = serde_json::from_value::<ProviderConfig>(provider_config_value.clone()) {
46 return Ok(provider_config);
47 }
48 }
49
50 Err(LangExtractError::configuration(
52 "Provider configuration is required. Please specify a provider either:\n\
53 1. Via CLI: --provider <openai|ollama|custom>\n\
54 2. Via config: Set language_model_params['provider_config']\n\
55 3. Via ProviderConfig in code\n\n\
56 Auto-detection based on model names has been removed for explicit configuration."
57 ))
58}
59
60
61
#[cfg(test)]
mod tests {
    use super::*;

    /// Config for the "mistral" model with no API key and no `provider_config` entry.
    fn mistral_config() -> ExtractConfig {
        ExtractConfig {
            model_id: "mistral".to_string(),
            api_key: None,
            ..Default::default()
        }
    }

    #[test]
    fn test_explicit_provider_required() {
        let err = create_provider_config(&mistral_config())
            .expect_err("a config without provider_config must be rejected");
        assert!(err.to_string().contains("Provider configuration is required"));
    }

    #[test]
    fn test_create_provider_config_with_explicit_config() {
        let expected =
            ProviderConfig::ollama("mistral", Some("http://localhost:11434".to_string()));
        let encoded = serde_json::to_value(&expected).unwrap();

        let mut config = mistral_config();
        config
            .language_model_params
            .insert("provider_config".to_string(), encoded);

        let parsed = create_provider_config(&config).unwrap();
        assert_eq!(parsed.provider_type, ProviderType::Ollama);
        assert_eq!(parsed.model, "mistral");
        assert_eq!(parsed.base_url, "http://localhost:11434");
    }
}