use std::sync::LazyLock;
use prompty::interfaces::InvokerError;
use prompty::model::ModelInfo;
use serde_json::Value;
static HTTP_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(reqwest::Client::new);
struct KnownModel {
prefix: &'static str,
context_window: Option<i32>,
input_modalities: &'static [&'static str],
output_modalities: &'static [&'static str],
}
static KNOWN_MODELS: &[KnownModel] = &[
KnownModel {
prefix: "gpt-4o-mini",
context_window: Some(128_000),
input_modalities: &["text", "image"],
output_modalities: &["text"],
},
KnownModel {
prefix: "gpt-4o",
context_window: Some(128_000),
input_modalities: &["text", "image"],
output_modalities: &["text"],
},
KnownModel {
prefix: "gpt-4-turbo",
context_window: Some(128_000),
input_modalities: &["text", "image"],
output_modalities: &["text"],
},
KnownModel {
prefix: "gpt-4",
context_window: Some(8_192),
input_modalities: &["text"],
output_modalities: &["text"],
},
KnownModel {
prefix: "gpt-3.5-turbo",
context_window: Some(16_385),
input_modalities: &["text"],
output_modalities: &["text"],
},
KnownModel {
prefix: "text-embedding-3-small",
context_window: Some(8_191),
input_modalities: &["text"],
output_modalities: &[],
},
KnownModel {
prefix: "text-embedding-3-large",
context_window: Some(8_191),
input_modalities: &["text"],
output_modalities: &[],
},
KnownModel {
prefix: "dall-e-3",
context_window: None,
input_modalities: &["text"],
output_modalities: &["image"],
},
];
fn find_known(id: &str) -> Option<&'static KnownModel> {
KNOWN_MODELS.iter().find(|km| id.starts_with(km.prefix))
}
fn build_models_url(connection: &Value) -> String {
let endpoint = connection
.get("endpoint")
.and_then(|e| e.as_str())
.filter(|s| !s.is_empty())
.map(String::from)
.or_else(|| {
std::env::var("OPENAI_BASE_URL")
.ok()
.filter(|s| !s.is_empty())
})
.unwrap_or_else(|| "https://api.openai.com".to_string());
let base = endpoint.trim_end_matches('/');
let path = "/v1/models";
let adjusted = if base.ends_with("/v1") || base.ends_with("/v1/") {
path.strip_prefix("/v1").unwrap_or(path)
} else {
path
};
format!("{base}{adjusted}")
}
fn get_api_key(connection: &Value) -> Result<String, InvokerError> {
if let Some(key) = connection
.get("apiKey")
.or(connection.get("api_key"))
.and_then(|k| k.as_str())
{
if !key.is_empty() {
return Ok(key.to_string());
}
}
if let Ok(key) = std::env::var("OPENAI_API_KEY") {
if !key.is_empty() {
return Ok(key);
}
}
Err(InvokerError::Execute(
"No API key found. Set OPENAI_API_KEY or configure connection.apiKey"
.to_string()
.into(),
))
}
fn parse_model_object(obj: &Value) -> ModelInfo {
let id = obj
.get("id")
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string();
let owned_by = obj
.get("owned_by")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let known = find_known(&id);
ModelInfo {
id,
display_name: None,
owned_by,
context_window: known.and_then(|k| k.context_window),
input_modalities: known.map(|k| {
k.input_modalities
.iter()
.map(|s| (*s).to_string())
.collect()
}),
output_modalities: known.map(|k| {
k.output_modalities
.iter()
.map(|s| (*s).to_string())
.collect()
}),
additional_properties: serde_json::Value::Null,
}
}
pub async fn list_models_async(connection: &Value) -> Result<Vec<ModelInfo>, InvokerError> {
let url = build_models_url(connection);
let api_key = get_api_key(connection)?;
let client = &*HTTP_CLIENT;
let response = client
.get(&url)
.header("Authorization", format!("Bearer {api_key}"))
.send()
.await
.map_err(|e| InvokerError::Execute(format!("HTTP request failed: {e}").into()))?;
if !response.status().is_success() {
let status = response.status();
let body_text = response
.text()
.await
.unwrap_or_else(|_| "unable to read body".to_string());
return Err(InvokerError::Execute(
format!("OpenAI list models error (HTTP {status}): {body_text}").into(),
));
}
let body: Value = response
.json()
.await
.map_err(|e| InvokerError::Execute(format!("Failed to parse response: {e}").into()))?;
let models = body
.get("data")
.and_then(|d| d.as_array())
.map(|arr| arr.iter().map(parse_model_object).collect())
.unwrap_or_default();
Ok(models)
}
pub fn list_models(connection: &Value) -> Result<Vec<ModelInfo>, InvokerError> {
tokio::runtime::Handle::try_current()
.map_err(|_| {
InvokerError::Execute(
"list_models requires a tokio runtime; use list_models_async instead"
.to_string()
.into(),
)
})
.and_then(|_| {
futures::executor::block_on(list_models_async(connection))
})
.or_else(|_| {
let rt = tokio::runtime::Runtime::new().map_err(|e| {
InvokerError::Execute(format!("Failed to create tokio runtime: {e}").into())
})?;
rt.block_on(list_models_async(connection))
})
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_build_models_url_default() {
let conn = serde_json::json!({});
let _prev = std::env::var("OPENAI_BASE_URL").ok();
unsafe { std::env::remove_var("OPENAI_BASE_URL") };
let url = build_models_url(&conn);
assert_eq!(url, "https://api.openai.com/v1/models");
}
#[test]
fn test_build_models_url_custom_endpoint() {
let conn = serde_json::json!({
"endpoint": "https://custom.openai.com/"
});
let url = build_models_url(&conn);
assert_eq!(url, "https://custom.openai.com/v1/models");
}
#[test]
fn test_build_models_url_with_v1_suffix() {
let conn = serde_json::json!({
"endpoint": "https://proxy.example.com/openai/v1"
});
let url = build_models_url(&conn);
assert_eq!(url, "https://proxy.example.com/openai/v1/models");
}
#[test]
fn test_get_api_key_from_connection() {
let conn = serde_json::json!({
"apiKey": "sk-from-conn"
});
let key = get_api_key(&conn).unwrap();
assert_eq!(key, "sk-from-conn");
}
#[test]
fn test_get_api_key_missing() {
unsafe { std::env::remove_var("OPENAI_API_KEY") };
let conn = serde_json::json!({});
let result = get_api_key(&conn);
assert!(result.is_err());
}
#[test]
fn test_find_known_gpt4o() {
let km = find_known("gpt-4o").unwrap();
assert_eq!(km.context_window, Some(128_000));
assert_eq!(km.input_modalities, &["text", "image"]);
}
#[test]
fn test_find_known_gpt4o_mini() {
let km = find_known("gpt-4o-mini-2024-07-18").unwrap();
assert_eq!(km.context_window, Some(128_000));
assert_eq!(km.prefix, "gpt-4o-mini");
}
#[test]
fn test_find_known_gpt4() {
let km = find_known("gpt-4-0613").unwrap();
assert_eq!(km.context_window, Some(8_192));
assert_eq!(km.input_modalities, &["text"]);
}
#[test]
fn test_find_known_dalle3() {
let km = find_known("dall-e-3").unwrap();
assert!(km.context_window.is_none());
assert_eq!(km.output_modalities, &["image"]);
}
#[test]
fn test_find_known_unknown_model() {
assert!(find_known("some-custom-model").is_none());
}
#[test]
fn test_parse_model_object_known() {
let obj = serde_json::json!({
"id": "gpt-4o",
"owned_by": "openai",
"object": "model"
});
let info = parse_model_object(&obj);
assert_eq!(info.id, "gpt-4o");
assert_eq!(info.owned_by.as_deref(), Some("openai"));
assert_eq!(info.context_window, Some(128_000));
assert_eq!(
info.input_modalities.as_deref(),
Some(vec!["text".to_string(), "image".to_string()].as_slice())
);
}
#[test]
fn test_parse_model_object_unknown() {
let obj = serde_json::json!({
"id": "ft:custom:user-123",
"owned_by": "user-123"
});
let info = parse_model_object(&obj);
assert_eq!(info.id, "ft:custom:user-123");
assert!(info.context_window.is_none());
assert!(info.input_modalities.is_none());
}
#[test]
fn test_parse_model_object_embedding() {
let obj = serde_json::json!({
"id": "text-embedding-3-small",
"owned_by": "openai"
});
let info = parse_model_object(&obj);
assert_eq!(info.context_window, Some(8_191));
assert_eq!(
info.output_modalities.as_deref(),
Some(Vec::<String>::new().as_slice())
);
}
}