use crate::brain::Provider;
use crate::brain::provider::custom_openai_compatible::OpenAIProvider;
use crate::brain::provider::factory::{create_provider, create_provider_by_name};
use crate::config::{Config, ProviderConfig, ProviderConfigs};
use std::collections::BTreeMap;
/// Builds a `Config` containing exactly one custom provider, keyed `"lmstudio"`.
///
/// The entry is enabled, has no API key, points at `http://localhost:1234/v1`
/// and names `my-local-model` as its default model. `context_window` is stored
/// on the entry exactly as given; every other field takes its `Default` value.
fn config_with_context_window(context_window: Option<u32>) -> Config {
    let lmstudio = ProviderConfig {
        enabled: true,
        api_key: None,
        base_url: Some("http://localhost:1234/v1".to_string()),
        default_model: Some("my-local-model".to_string()),
        context_window,
        ..Default::default()
    };

    // Single-entry map built in one expression instead of new() + insert().
    let custom = BTreeMap::from([("lmstudio".to_string(), lmstudio)]);

    let providers = ProviderConfigs {
        custom: Some(custom),
        ..Default::default()
    };

    Config {
        providers,
        ..Default::default()
    }
}
/// An explicit `with_context_window` value is reported for a model name
/// ("my-custom-model") that, per the sibling test without an override,
/// otherwise yields `None`.
#[test]
fn with_context_window_overrides_unknown_model() {
    let endpoint = "http://localhost:1234/v1/chat/completions".to_string();
    let local = OpenAIProvider::with_base_url(String::new(), endpoint)
        .with_name("local")
        .with_context_window(32_000);

    let reported = local.context_window("my-custom-model");

    assert_eq!(reported, Some(32_000));
}
/// An explicit `with_context_window` value is reported even for "gpt-4o",
/// a model name for which the provider is expected to return `128_000`
/// when no override is set (see the heuristic test in this file).
#[test]
fn with_context_window_overrides_known_model() {
    let endpoint = "http://localhost:1234/v1/chat/completions".to_string();
    let overridden =
        OpenAIProvider::with_base_url(String::new(), endpoint).with_context_window(64_000);

    let reported = overridden.context_window("gpt-4o");

    assert_eq!(reported, Some(64_000));
}
/// With no override configured, querying "my-custom-model" is expected to
/// yield `None`.
#[test]
fn without_context_window_returns_none_for_unknown() {
    let endpoint = "http://localhost:1234/v1/chat/completions".to_string();
    let plain = OpenAIProvider::with_base_url(String::new(), endpoint);

    let reported = plain.context_window("my-custom-model");

    assert_eq!(reported, None);
}
/// With no override configured, querying "gpt-4o" is expected to yield
/// `Some(128_000)`.
#[test]
fn without_context_window_returns_heuristic_for_known() {
    let endpoint = "http://localhost:1234/v1/chat/completions".to_string();
    let plain = OpenAIProvider::with_base_url(String::new(), endpoint);

    let reported = plain.context_window("gpt-4o");

    assert_eq!(reported, Some(128_000));
}
/// `create_provider` must carry the `context_window` value from the config
/// through to the provider it builds.
#[tokio::test]
async fn factory_passes_context_window_to_provider() {
    let config = config_with_context_window(Some(200_000));

    // `expect` makes a factory failure print the underlying error; a bare
    // `assert!(result.is_ok())` would only report that the assertion failed.
    let provider = create_provider(&config)
        .await
        .expect("create_provider should succeed for the lmstudio test config");

    assert_eq!(
        provider.context_window("my-local-model"),
        Some(200_000),
        "Factory should pass context_window from config to provider"
    );
}
/// When the config carries no `context_window`, the provider built by
/// `create_provider` is expected to report `None` for "my-local-model".
#[tokio::test]
async fn factory_no_context_window_returns_none_for_unknown() {
    let config = config_with_context_window(None);

    // `expect` makes a factory failure print the underlying error; a bare
    // `assert!(result.is_ok())` would only report that the assertion failed.
    let provider = create_provider(&config)
        .await
        .expect("create_provider should succeed for the lmstudio test config");

    assert_eq!(
        provider.context_window("my-local-model"),
        None,
        "Without context_window in config, unknown model should return None"
    );
}
/// `create_provider_by_name` with the `custom:lmstudio` selector must carry
/// the configured `context_window` through to the provider it builds.
#[tokio::test]
async fn factory_by_name_passes_context_window() {
    let config = config_with_context_window(Some(16_384));

    // `expect` makes a factory failure print the underlying error; a bare
    // `assert!(result.is_ok())` would only report that the assertion failed.
    let provider = create_provider_by_name(&config, "custom:lmstudio")
        .await
        .expect("create_provider_by_name should succeed for custom:lmstudio");

    assert_eq!(provider.context_window("whatever-model"), Some(16_384));
}
/// A `Some` context window must show up as a `context_window = …` line in
/// the TOML produced for a `ProviderConfig`.
#[test]
fn context_window_serializes_when_set() {
    let provider_cfg = ProviderConfig {
        context_window: Some(128_000),
        enabled: true,
        ..Default::default()
    };

    let rendered = toml::to_string(&provider_cfg).expect("serialize");
    let has_field = rendered.contains("context_window = 128000");

    assert!(
        has_field,
        "context_window should appear in serialized TOML: {}",
        rendered
    );
}
/// A `None` context window must leave no `context_window` key in the TOML
/// produced for a `ProviderConfig`.
#[test]
fn context_window_omitted_when_none() {
    let provider_cfg = ProviderConfig {
        context_window: None,
        enabled: true,
        ..Default::default()
    };

    let rendered = toml::to_string(&provider_cfg).expect("serialize");
    let mentions_field = rendered.contains("context_window");

    assert!(
        !mentions_field,
        "context_window should be omitted when None: {}",
        rendered
    );
}
/// A `context_window` key present in TOML is read into
/// `ProviderConfig::context_window` as `Some(value)`.
#[test]
fn context_window_deserializes_from_toml() {
    // The raw string is kept flush-left so its bytes match the fixture exactly.
    let parsed = toml::from_str::<ProviderConfig>(
        r#"
enabled = true
context_window = 32000
"#,
    )
    .expect("deserialize");

    assert_eq!(parsed.context_window, Some(32_000));
}
/// When the TOML has no `context_window` key, the deserialized
/// `ProviderConfig::context_window` is `None`.
#[test]
fn context_window_defaults_to_none_when_missing() {
    // The raw string is kept flush-left so its bytes match the fixture exactly.
    let parsed = toml::from_str::<ProviderConfig>(
        r#"
enabled = true
"#,
    )
    .expect("deserialize");

    assert_eq!(parsed.context_window, None);
}
/// With two custom providers configured, each provider built by the factory
/// must report the context window from its own config entry.
#[tokio::test]
async fn multiple_customs_each_get_own_context_window() {
    let nvidia = ProviderConfig {
        enabled: true,
        base_url: Some("https://integrate.api.nvidia.com/v1".to_string()),
        default_model: Some("llama-70b".to_string()),
        context_window: Some(128_000),
        ..Default::default()
    };
    // NOTE(review): `ollama` is disabled, yet the test below expects it to be
    // creatable by name. Presumably by-name lookup ignores `enabled`; confirm
    // against the factory.
    let ollama = ProviderConfig {
        enabled: false,
        base_url: Some("http://localhost:11434/v1".to_string()),
        default_model: Some("phi3".to_string()),
        context_window: Some(4_096),
        ..Default::default()
    };

    let customs = BTreeMap::from([
        ("nvidia".to_string(), nvidia),
        ("ollama".to_string(), ollama),
    ]);
    let config = Config {
        providers: ProviderConfigs {
            custom: Some(customs),
            ..Default::default()
        },
        ..Default::default()
    };

    // The default factory call is expected to yield a provider carrying the
    // `nvidia` entry's window.
    let default_provider = create_provider(&config).await.unwrap();
    assert_eq!(default_provider.context_window("llama-70b"), Some(128_000));

    // Selecting `ollama` explicitly must yield that entry's own window.
    let named_provider = create_provider_by_name(&config, "custom:ollama")
        .await
        .unwrap();
    assert_eq!(named_provider.context_window("phi3"), Some(4_096));
}
/// A zero override is passed through as `Some(0)`; it is not treated as
/// "unset".
#[test]
fn context_window_zero_is_valid() {
    let endpoint = "http://localhost:1234/v1/chat/completions".to_string();
    let zeroed = OpenAIProvider::with_base_url(String::new(), endpoint).with_context_window(0);

    let reported = zeroed.context_window("any-model");

    assert_eq!(reported, Some(0));
}
/// A two-million-token override is reported back unchanged.
#[test]
fn context_window_large_value() {
    let endpoint = "http://localhost:1234/v1/chat/completions".to_string();
    let huge =
        OpenAIProvider::with_base_url(String::new(), endpoint).with_context_window(2_000_000);

    let reported = huge.context_window("any-model");

    assert_eq!(reported, Some(2_000_000));
}