use crate::llm::api::{DeltaSender, LlmRequestPayload, LlmResult};
use crate::llm::provider::{LlmProvider, LlmProviderChat};
use crate::value::VmError;
/// Extracts the `(major, minor)` GPT generation from a model identifier.
///
/// Provider-prefixed names (e.g. `openai/gpt-5.4-preview`) are handled by
/// keeping only the segment after the last `/`. Both the dotted form
/// (`gpt-5.4`, `gpt-3.5-turbo`) and the dashed form (`gpt-5-4`) are
/// recognized. In the dashed form, a numeric component of 1000 or more is
/// treated as a date suffix (e.g. `gpt-5-20260115`) and the minor version
/// falls back to 0. Returns `None` when no parsable `gpt-<major>` is found.
pub(crate) fn gpt_generation(model: &str) -> Option<(u32, u32)> {
    let lowered = model.to_lowercase();
    // Strip any router/vendor prefix such as "openai/" or "azure/".
    let name = lowered
        .rsplit_once('/')
        .map_or(lowered.as_str(), |(_, tail)| tail);
    let marker = "gpt-";
    let after = &name[name.find(marker)? + marker.len()..];

    // Dotted form first: "5.4", "5.4-preview", "3.5-turbo".
    if let Some((major_part, rest)) = after.split_once('.') {
        if let Ok(major) = major_part.parse::<u32>() {
            // Only the leading digit run of the minor counts ("4-preview" -> "4").
            let digits: String = rest.chars().take_while(char::is_ascii_digit).collect();
            if let Ok(minor) = digits.parse::<u32>() {
                return Some((major, minor));
            }
        }
    }

    // Dashed form: "5-4", "5-20260115", or a bare "5".
    let mut pieces = after.split('-');
    let major: u32 = pieces.next()?.parse().ok()?;
    let minor = match pieces.next().and_then(|p| p.parse::<u32>().ok()) {
        // Four-or-more-digit components are release dates, not minors.
        Some(m) if m < 1000 => m,
        _ => 0,
    };
    Some((major, minor))
}
/// Returns `true` when `model` is a GPT generation of at least 5.4, the
/// first generation this module gates tool search on.
#[allow(dead_code)]
pub(crate) fn gpt_model_supports_tool_search(model: &str) -> bool {
    // Non-GPT / unparsable models never qualify.
    gpt_generation(model).map_or(false, |generation| generation >= (5, 4))
}
/// Chat provider speaking the OpenAI-compatible HTTP wire format
/// (used for OpenAI itself and OpenAI-style gateways such as OpenRouter).
pub(crate) struct OpenAiCompatibleProvider {
    /// Display/config name of the backing provider; also consulted for
    /// provider-specific request tweaks (see `transform_request`).
    provider_name: String,
}
impl OpenAiCompatibleProvider {
    /// Creates a provider identified by `name` (e.g. "openai", "openrouter").
    pub(crate) fn new(name: String) -> Self {
        Self { provider_name: name }
    }
}
impl LlmProvider for OpenAiCompatibleProvider {
    /// Returns the configured provider name.
    fn name(&self) -> &str {
        self.provider_name.as_str()
    }

    /// Applies provider-specific fixups to the request body before dispatch.
    ///
    /// Strips the non-standard `chat_template_kwargs` field that
    /// `build_request_body` always attaches when the provider name mentions
    /// "openrouter" (NOTE(review): presumably OpenRouter rejects or ignores
    /// it — confirm against its API).
    fn transform_request(&self, body: &mut serde_json::Value) {
        let targets_openrouter = self.provider_name.to_lowercase().contains("openrouter");
        if !targets_openrouter {
            return;
        }
        if let Some(map) = body.as_object_mut() {
            map.remove("chat_template_kwargs");
        }
    }
}
impl LlmProviderChat for OpenAiCompatibleProvider {
    /// Trait entry point for chat requests: pins the `chat_impl` future so it
    /// can be returned as a boxed `dyn Future` borrowing `self` and `request`
    /// for the lifetime `'a`.
    fn chat<'a>(
        &'a self,
        request: &'a LlmRequestPayload,
        delta_tx: Option<DeltaSender>,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<LlmResult, VmError>> + 'a>> {
        Box::pin(self.chat_impl(request, delta_tx))
    }
}
impl OpenAiCompatibleProvider {
pub(crate) fn build_request_body(
opts: &LlmRequestPayload,
force_string_content: bool,
) -> serde_json::Value {
let mut msgs = Vec::new();
if let Some(ref sys) = opts.system {
msgs.push(serde_json::json!({"role": "system", "content": sys}));
}
msgs.extend(opts.messages.iter().cloned());
if let Some(ref prefill) = opts.prefill {
msgs.push(serde_json::json!({
"role": "assistant",
"content": prefill,
}));
}
msgs = crate::llm::api::normalize_openai_style_messages(msgs, force_string_content);
let mut body = serde_json::json!({
"model": opts.model,
"messages": msgs,
});
if opts.max_tokens > 0 {
body["max_tokens"] = serde_json::json!(opts.max_tokens);
}
if let Some(temp) = opts.temperature {
body["temperature"] = serde_json::json!(temp);
}
if let Some(top_p) = opts.top_p {
body["top_p"] = serde_json::json!(top_p);
}
if let Some(ref stop) = opts.stop {
body["stop"] = serde_json::json!(stop);
}
if let Some(seed) = opts.seed {
body["seed"] = serde_json::json!(seed);
}
if let Some(fp) = opts.frequency_penalty {
body["frequency_penalty"] = serde_json::json!(fp);
}
if let Some(pp) = opts.presence_penalty {
body["presence_penalty"] = serde_json::json!(pp);
}
if opts.response_format.as_deref() == Some("json") {
if let Some(ref schema) = opts.json_schema {
body["response_format"] = serde_json::json!({
"type": "json_schema",
"json_schema": {
"name": "response",
"schema": schema,
"strict": true,
}
});
} else {
body["response_format"] = serde_json::json!({"type": "json_object"});
}
}
if let Some(ref tools) = opts.native_tools {
if !tools.is_empty() {
body["tools"] = serde_json::json!(tools);
}
}
if let Some(ref tc) = opts.tool_choice {
body["tool_choice"] = tc.clone();
}
let mut chat_template_kwargs = serde_json::json!({
"enable_thinking": opts.thinking.is_some(),
});
if opts.prefill.is_some() {
chat_template_kwargs["add_generation_prompt"] = serde_json::json!(false);
chat_template_kwargs["continue_final_message"] = serde_json::json!(true);
}
body["chat_template_kwargs"] = chat_template_kwargs;
body
}
pub(crate) async fn chat_impl(
&self,
request: &LlmRequestPayload,
delta_tx: Option<DeltaSender>,
) -> Result<LlmResult, VmError> {
let mut body = Self::build_request_body(request, false);
self.transform_request(&mut body);
crate::llm::api::vm_call_llm_api_with_body(
request, delta_tx, body, false, false, )
.await
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tool_search_supported_for_gpt_5_4_and_up() {
        // Every spelling of >= 5.4 must qualify.
        let supported = [
            "gpt-5.4",
            "gpt-5.4-preview",
            "gpt-5.4-turbo",
            "gpt-5-4",
            "gpt-5.5",
            "gpt-6.0",
        ];
        for model in supported {
            assert!(gpt_model_supports_tool_search(model), "expected support: {}", model);
        }
    }

    #[test]
    fn tool_search_unsupported_for_pre_5_4() {
        // Anything that parses below (5, 4) must be rejected.
        let unsupported = [
            "gpt-4o",
            "gpt-4.1",
            "gpt-4-turbo",
            "gpt-3.5-turbo",
            "gpt-5.0",
            "gpt-5.3-preview",
            "gpt-5",
        ];
        for model in unsupported {
            assert!(!gpt_model_supports_tool_search(model), "unexpected support: {}", model);
        }
    }

    #[test]
    fn tool_search_unsupported_for_non_gpt() {
        for model in ["claude-opus-4-7", "llama-3.1-70b", ""] {
            assert!(!gpt_model_supports_tool_search(model), "unexpected support: {}", model);
        }
    }

    #[test]
    fn gpt_generation_handles_openrouter_prefix() {
        // Router prefixes like "openai/" must not hide the version.
        assert_eq!(gpt_generation("openai/gpt-5.4-preview"), Some((5, 4)));
        assert_eq!(gpt_generation("azure/gpt-5.5-turbo"), Some((5, 5)));
        assert!(gpt_model_supports_tool_search("openai/gpt-5.4"));
        assert!(!gpt_model_supports_tool_search("openai/gpt-4o"));
    }

    #[test]
    fn gpt_generation_ignores_date_suffix_as_minor() {
        // A date-like dashed component is not a minor version.
        assert_eq!(gpt_generation("gpt-5-20260115"), Some((5, 0)));
        assert!(!gpt_model_supports_tool_search("gpt-5-20260115"));
    }

    #[test]
    fn native_tool_search_variants_lists_hosted_first() {
        let provider = OpenAiCompatibleProvider::new("openai".to_string());
        let variants = provider.native_tool_search_variants("gpt-5.4-preview");
        assert_eq!(variants, vec!["hosted".to_string(), "client".to_string()]);
    }

    #[test]
    fn native_tool_search_variants_empty_for_old_model() {
        let provider = OpenAiCompatibleProvider::new("openai".to_string());
        assert!(provider.native_tool_search_variants("gpt-4o").is_empty());
    }

    #[test]
    fn supports_defer_loading_matches_tool_search_gate() {
        let provider = OpenAiCompatibleProvider::new("openai".to_string());
        assert!(provider.supports_defer_loading("gpt-5.4"));
        assert!(!provider.supports_defer_loading("gpt-4o"));
    }
}