use std::collections::HashMap;
use std::sync::Arc;
use oagw_sdk::ServiceGatewayClientV1;
use tracing::{info, warn};
use crate::config::ProviderEntry;
/// Registers every configured provider with the OAGW gateway.
///
/// For each provider entry this:
/// 1. creates the provider upstream and stores the returned alias in
///    `entry.upstream_alias`,
/// 2. registers the provider routes on that upstream,
/// 3. creates one additional upstream per tenant override and stores the
///    returned alias on that override.
///
/// # Errors
///
/// Returns an error as soon as any upstream or route registration fails, or
/// when a tenant override defines neither `host` nor `upstream_alias`.
/// This function does not undo registrations made before the failure.
pub async fn register_oagw_upstreams(
    gateway: &Arc<dyn ServiceGatewayClientV1>,
    ctx: &modkit_security::SecurityContext,
    providers: &mut HashMap<String, ProviderEntry>,
) -> anyhow::Result<()> {
    for (provider_id, entry) in providers.iter_mut() {
        let upstream = match create_upstream(gateway, ctx, provider_id, entry).await {
            Some(created) => created,
            None => anyhow::bail!(
                "OAGW upstream registration failed for provider '{provider_id}'"
            ),
        };
        // Remember the alias the gateway reported for this provider.
        entry.upstream_alias = Some(upstream.alias.clone());

        if let Err(e) = register_route(gateway, ctx, provider_id, entry, &upstream).await {
            anyhow::bail!("OAGW route registration failed for provider '{provider_id}': {e}");
        }

        // Snapshot the keys so the overrides can be mutated while walking them.
        let tenant_ids: Vec<String> = entry.tenant_overrides.keys().cloned().collect();
        for tenant_id in &tenant_ids {
            let current = &entry.tenant_overrides[tenant_id];
            let has_target = current.host.is_some() || current.upstream_alias.is_some();
            if !has_target {
                anyhow::bail!(
                    "provider '{provider_id}': tenant override '{tenant_id}' \
                     has no host and no upstream_alias - \
                     cannot create distinct upstream"
                );
            }

            let label = format!("{provider_id}[tenant={tenant_id}]");
            let tenant_alias =
                match create_tenant_upstream(gateway, ctx, &label, entry, tenant_id).await {
                    Some(created_alias) => created_alias,
                    None => anyhow::bail!(
                        "OAGW tenant upstream registration failed for provider '{provider_id}', tenant '{tenant_id}'"
                    ),
                };

            if let Some(slot) = entry.tenant_overrides.get_mut(tenant_id) {
                slot.upstream_alias = Some(tenant_alias);
            }
        }
    }

    Ok(())
}
/// Builds the gateway endpoint described by a provider entry.
///
/// The scheme follows `entry.use_http`. When `entry.port` is unset the port
/// defaults to 80 for HTTP and 443 for HTTPS.
fn endpoint_for(entry: &ProviderEntry) -> oagw_sdk::Endpoint {
    use oagw_sdk::{Endpoint, Scheme};

    // Pick the scheme and its default port together so they cannot drift apart.
    let (scheme, default_port) = if entry.use_http {
        (Scheme::Http, 80)
    } else {
        (Scheme::Https, 443)
    };

    Endpoint {
        scheme,
        host: entry.host.clone(),
        port: entry.port.unwrap_or(default_port),
    }
}
/// Creates the provider-level upstream in the gateway.
///
/// The request targets the endpoint from [`endpoint_for`], is enabled, reuses
/// `entry.upstream_alias` when one is set, and carries an auth config only
/// when both `auth_plugin_type` and `auth_config` are present.
///
/// Returns the created upstream, or `None` after logging a warning when the
/// gateway call fails.
async fn create_upstream(
    gateway: &Arc<dyn ServiceGatewayClientV1>,
    ctx: &modkit_security::SecurityContext,
    provider_id: &str,
    entry: &ProviderEntry,
) -> Option<oagw_sdk::Upstream> {
    use oagw_sdk::{AuthConfig, CreateUpstreamRequest, Server};

    let mut request = CreateUpstreamRequest::builder(
        Server {
            endpoints: vec![endpoint_for(entry)],
        },
        "gts.x.core.oagw.protocol.v1~x.core.oagw.http.v1",
    )
    .enabled(true);

    if let Some(alias) = entry.upstream_alias.as_ref() {
        request = request.alias(alias);
    }

    // Auth is attached only when both halves of the configuration exist.
    let auth_parts = entry
        .auth_plugin_type
        .as_ref()
        .zip(entry.auth_config.as_ref());
    if let Some((plugin_type, config)) = auth_parts {
        request = request.auth(AuthConfig {
            plugin_type: plugin_type.clone(),
            sharing: oagw_sdk::SharingMode::Inherit,
            config: Some(config.clone()),
        });
    }

    let created = match gateway.create_upstream(ctx.clone(), request.build()).await {
        Ok(created) => created,
        Err(err) => {
            warn!(
                provider_id,
                error = %err,
                "OAGW upstream registration failed (may already exist)"
            );
            return None;
        }
    };

    info!(
        provider_id,
        alias = %created.alias,
        upstream_id = %created.id,
        "OAGW upstream registered"
    );
    Some(created)
}
/// Creates a tenant-specific upstream in the gateway.
///
/// The endpoint starts from [`endpoint_for`] and has its host replaced by
/// `entry.effective_host_for_tenant(tenant_id)`. The tenant override's
/// `upstream_alias` is reused when present, and the tenant's effective auth
/// plugin type and config are attached when both are available.
///
/// `label` is used only for log output.
///
/// Returns the alias of the created upstream, or `None` after logging a
/// warning when the gateway call fails.
async fn create_tenant_upstream(
    gateway: &Arc<dyn ServiceGatewayClientV1>,
    ctx: &modkit_security::SecurityContext,
    label: &str,
    entry: &ProviderEntry,
    tenant_id: &str,
) -> Option<String> {
    use oagw_sdk::{AuthConfig, CreateUpstreamRequest, Server};

    // Same scheme/port as the provider, but pointing at the tenant's host.
    let tenant_host = entry.effective_host_for_tenant(tenant_id);
    let mut endpoint = endpoint_for(entry);
    tenant_host.clone_into(&mut endpoint.host);

    let mut request = CreateUpstreamRequest::builder(
        Server {
            endpoints: vec![endpoint],
        },
        "gts.x.core.oagw.protocol.v1~x.core.oagw.http.v1",
    )
    .enabled(true);

    let configured_alias = entry
        .tenant_overrides
        .get(tenant_id)
        .and_then(|o| o.upstream_alias.as_deref());
    if let Some(alias) = configured_alias {
        request = request.alias(alias);
    }

    // Auth is attached only when both halves of the configuration exist.
    let auth_parts = entry
        .effective_auth_plugin_type_for_tenant(tenant_id)
        .zip(entry.effective_auth_config_for_tenant(tenant_id));
    if let Some((plugin_type, config)) = auth_parts {
        request = request.auth(AuthConfig {
            plugin_type: plugin_type.to_owned(),
            sharing: oagw_sdk::SharingMode::Inherit,
            config: Some(config.clone()),
        });
    }

    let created = match gateway.create_upstream(ctx.clone(), request.build()).await {
        Ok(created) => created,
        Err(err) => {
            warn!(
                label,
                error = %err,
                "OAGW tenant upstream registration failed (may already exist)"
            );
            return None;
        }
    };

    info!(
        label,
        alias = %created.alias,
        upstream_id = %created.id,
        "OAGW tenant upstream registered"
    );
    Some(created.alias)
}
/// Registers the provider's main `POST` route on `upstream`, followed by the
/// RAG routes.
///
/// The route path and suffix mode come from [`derive_route_match`] and the
/// query allowlist from [`extract_query_allowlist`], both applied to
/// `entry.api_path`.
///
/// # Errors
///
/// Returns the gateway error when the main route cannot be created, or any
/// error returned by [`register_rag_routes`].
async fn register_route(
    gateway: &Arc<dyn ServiceGatewayClientV1>,
    ctx: &modkit_security::SecurityContext,
    provider_id: &str,
    entry: &ProviderEntry,
    upstream: &oagw_sdk::Upstream,
) -> anyhow::Result<()> {
    use oagw_sdk::{CreateRouteRequest, HttpMatch, HttpMethod, MatchRules};

    let (route_prefix, suffix_mode) = derive_route_match(&entry.api_path);
    let http_rules = HttpMatch {
        methods: vec![HttpMethod::Post],
        // Cloned because the prefix is also reported in the log line below.
        path: route_prefix.clone(),
        query_allowlist: extract_query_allowlist(&entry.api_path),
        path_suffix_mode: suffix_mode,
    };
    let request = CreateRouteRequest::builder(
        upstream.id,
        MatchRules {
            http: Some(http_rules),
            grpc: None,
        },
    )
    .enabled(true)
    .build();

    let route = gateway.create_route(ctx.clone(), request).await?;
    info!(
        provider_id,
        route_id = %route.id,
        route_path = %route_prefix,
        "OAGW route registered"
    );

    register_rag_routes(gateway, ctx, provider_id, entry, upstream).await
}
/// Route table consumed by `register_rag_routes`.
///
/// Each entry is `(HTTP method, path suffix, append suffix)`:
/// - the method string is mapped to an `HttpMethod`; only `"POST"` and
///   `"DELETE"` are recognised and any other value is skipped,
/// - the path suffix is appended to a storage-kind specific prefix
///   (`/openai` or `/v1`),
/// - the flag selects `PathSuffixMode::Append` when `true` and
///   `PathSuffixMode::Disabled` when `false`.
const RAG_ROUTES: &[(&str, &str, bool)] = &[
("POST", "/files", false),
// NOTE(review): presumably suffix-append is enabled on the entries below so a
// resource id can follow the path - confirm against the gateway semantics.
("DELETE", "/files", true),
("POST", "/vector_stores", true),
("DELETE", "/vector_stores", true),
];
/// Registers every route listed in `RAG_ROUTES` on `upstream`.
///
/// The path prefix and query allowlist depend on `entry.storage_kind`:
/// Azure uses `/openai` with `api-version` allowed, OpenAI uses `/v1` with an
/// empty allowlist.
///
/// Registration is best-effort: a failing route is logged as a warning and
/// the loop continues, so this function currently always returns `Ok(())`.
#[allow(clippy::cognitive_complexity)]
async fn register_rag_routes(
    gateway: &Arc<dyn ServiceGatewayClientV1>,
    ctx: &modkit_security::SecurityContext,
    provider_id: &str,
    entry: &ProviderEntry,
    upstream: &oagw_sdk::Upstream,
) -> anyhow::Result<()> {
    use oagw_sdk::{CreateRouteRequest, HttpMatch, HttpMethod, MatchRules, PathSuffixMode};

    let (prefix, query_allowlist) = match entry.storage_kind {
        crate::config::StorageKind::Azure => ("/openai", vec!["api-version".to_owned()]),
        crate::config::StorageKind::OpenAi => ("/v1", vec![]),
    };

    for (verb, path_suffix, append_suffix) in RAG_ROUTES.iter().copied() {
        // Table entries with an unrecognised method are silently skipped.
        let method = if verb == "POST" {
            HttpMethod::Post
        } else if verb == "DELETE" {
            HttpMethod::Delete
        } else {
            continue;
        };
        let suffix_mode = if append_suffix {
            PathSuffixMode::Append
        } else {
            PathSuffixMode::Disabled
        };

        let full_path = format!("{prefix}{path_suffix}");
        let http_rules = HttpMatch {
            methods: vec![method],
            path: full_path.clone(),
            query_allowlist: query_allowlist.clone(),
            path_suffix_mode: suffix_mode,
        };
        let request = CreateRouteRequest::builder(
            upstream.id,
            MatchRules {
                http: Some(http_rules),
                grpc: None,
            },
        )
        .enabled(true)
        .build();

        match gateway.create_route(ctx.clone(), request).await {
            Err(err) => {
                warn!(
                    provider_id,
                    error = %err,
                    route_path = %full_path,
                    method = verb,
                    "OAGW RAG route registration failed (may already exist)"
                );
            }
            Ok(route) => {
                info!(
                    provider_id,
                    route_id = %route.id,
                    route_path = %full_path,
                    method = verb,
                    "OAGW RAG route registered"
                );
            }
        }
    }

    Ok(())
}
/// Derives the gateway route path and suffix mode from a provider API path.
///
/// The query string (everything from the first `?`) is dropped and every
/// `{model}` placeholder is treated as a `*` wildcard.
///
/// - Without a wildcard the whole path is returned with
///   `PathSuffixMode::Disabled`.
/// - With a wildcard the path is cut at the first `*`, trailing slashes are
///   trimmed, and `PathSuffixMode::Append` is returned.
///
/// Note that a wildcard in the first path segment (e.g. `/{model}/x`) yields
/// an empty prefix.
fn derive_route_match(api_path: &str) -> (String, oagw_sdk::PathSuffixMode) {
    let path_only = match api_path.split_once('?') {
        Some((before_query, _)) => before_query,
        None => api_path,
    };
    let templated = path_only.replace("{model}", "*");

    // A single search decides both the prefix and the suffix mode.
    match templated.find('*') {
        Some(star) => (
            templated[..star].trim_end_matches('/').to_owned(),
            oagw_sdk::PathSuffixMode::Append,
        ),
        None => (templated, oagw_sdk::PathSuffixMode::Disabled),
    }
}
/// Extracts the query parameter names from a provider API path.
///
/// Everything after the first `?` is treated as the query string; a later
/// `?` is part of the query rather than a second separator. The query is
/// split on `&` and the text before the first `=` of each pair is taken as
/// the parameter name.
///
/// Empty names are skipped, so a trailing `?`, a doubled `&&`, a trailing `&`
/// or a pair such as `=value` never put an empty string into the allowlist.
///
/// Returns an empty vector when the path has no query string.
fn extract_query_allowlist(api_path: &str) -> Vec<String> {
    let query = match api_path.split_once('?') {
        Some((_, query)) => query,
        None => return Vec::new(),
    };

    query
        .split('&')
        // A pair without `=` is a bare flag: the whole pair is the name.
        .map(|pair| pair.split_once('=').map_or(pair, |(name, _)| name))
        .filter(|name| !name.is_empty())
        .map(str::to_owned)
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    // A plain path is returned untouched and suffix matching stays off.
    #[test]
    fn derive_simple_path() {
        let derived = derive_route_match("/v1/responses");
        assert_eq!(derived.0, "/v1/responses");
        assert!(matches!(derived.1, oagw_sdk::PathSuffixMode::Disabled));
    }

    // `{model}` acts as a wildcard: the prefix stops before it.
    #[test]
    fn derive_path_with_model_placeholder() {
        let api_path = "/openai/deployments/{model}/responses?api-version=2025-03-01";
        let derived = derive_route_match(api_path);
        assert_eq!(derived.0, "/openai/deployments");
        assert!(matches!(derived.1, oagw_sdk::PathSuffixMode::Append));
    }

    // An Azure-style path without placeholder behaves like any plain path.
    #[test]
    fn derive_azure_openai_path() {
        let derived = derive_route_match("/openai/v1/responses");
        assert_eq!(derived.0, "/openai/v1/responses");
        assert!(matches!(derived.1, oagw_sdk::PathSuffixMode::Disabled));
    }

    // No `?` means no allowlisted parameters.
    #[test]
    fn extract_empty_query() {
        let names = extract_query_allowlist("/v1/responses");
        assert!(names.is_empty());
    }

    // Only the parameter name is kept, not its value.
    #[test]
    fn extract_single_query_param() {
        let api_path = "/openai/deployments/{model}/responses?api-version=2025-03-01";
        assert_eq!(extract_query_allowlist(api_path), vec!["api-version"]);
    }

    // Names are returned in the order they appear in the query string.
    #[test]
    fn extract_multiple_query_params() {
        assert_eq!(
            extract_query_allowlist("/path?foo=1&bar=2&baz=3"),
            vec!["foo", "bar", "baz"]
        );
    }

    // A literal `*` is handled like `{model}` and the trailing `/` is trimmed.
    #[test]
    fn derive_trailing_wildcard_strips_trailing_slash() {
        let derived = derive_route_match("/v1/models/*/completions");
        assert_eq!(derived.0, "/v1/models");
        assert!(matches!(derived.1, oagw_sdk::PathSuffixMode::Append));
    }

    // The root path without wildcard is preserved as `/`.
    #[test]
    fn derive_root_path() {
        let derived = derive_route_match("/");
        assert_eq!(derived.0, "/");
        assert!(matches!(derived.1, oagw_sdk::PathSuffixMode::Disabled));
    }

    // The query string never leaks into the derived route path.
    #[test]
    fn derive_query_string_stripped_before_matching() {
        let derived = derive_route_match("/v1/responses?stream=true");
        assert_eq!(derived.0, "/v1/responses");
        assert!(matches!(derived.1, oagw_sdk::PathSuffixMode::Disabled));
    }

    // A parameter with an empty value still contributes its name.
    #[test]
    fn extract_query_params_with_empty_values() {
        assert_eq!(
            extract_query_allowlist("/path?key=&other=val"),
            vec!["key", "other"]
        );
    }
}