1use axum::{
8 body::Body,
9 extract::{Extension, Query, State},
10 http::{Request as HttpRequest, StatusCode},
11 middleware::{self, Next},
12 response::{IntoResponse, Response},
13 routing::{get, post},
14 Json, Router,
15};
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use std::collections::HashMap;
19use std::net::SocketAddr;
20use std::path::PathBuf;
21use std::sync::Arc;
22
23use crate::core::auth_generator::{AuthCache, GenContext};
24use crate::core::http;
25use crate::core::jwt::{self, JwtConfig, TokenClaims};
26use crate::core::keyring::Keyring;
27use crate::core::manifest::{ManifestRegistry, Provider, Tool};
28use crate::core::mcp_client;
29use crate::core::response;
30use crate::core::scope::ScopeConfig;
31use crate::core::sentry_scope;
32use crate::core::skill::{self, SkillRegistry};
33use crate::core::skillati::{RemoteSkillMeta, SkillAtiClient, SkillAtiError};
34
/// Shared state for the proxy's HTTP handlers.
///
/// Handlers receive it as `State<Arc<ProxyState>>`.
pub struct ProxyState {
    /// Tool/provider manifests; used for tool lookup (`get_tool`) and listings.
    pub registry: ManifestRegistry,
    /// Locally registered skills; filtered per caller scopes before being exposed.
    pub skill_registry: SkillRegistry,
    /// Secret store. It holds provider credentials, the help LLM's API key, and the
    /// per-provider `{provider}_allowed_urls` upstream allowlists.
    pub keyring: Keyring,
    /// JWT verification settings. `None` means JWT auth is disabled: `/health` reports
    /// `"disabled"`, and callers without claims get unrestricted scopes.
    pub jwt_config: Option<JwtConfig>,
    /// JWKS document returned by `handle_jwks`; `None` makes that handler answer 404.
    pub jwks_json: Option<Value>,
    /// Cache handed (as `Some(&auth_cache)`) to the tool executors' auth generators.
    pub auth_cache: AuthCache,
    /// Compiled upstream URL allowlists, keyed by provider name.
    ///
    /// Filled lazily by `resolve_upstream_override`. A `None` value is cached when the
    /// keyring entry is missing or fails to compile.
    pub upstream_url_allowlists:
        Arc<std::sync::Mutex<std::collections::HashMap<String, Option<Vec<UpstreamAllowEntry>>>>>,
    /// Tool input schemas discovered from MCP upstreams, keyed by `(provider, upstream_url)`.
    pub lazy_schema_cache: LazySchemaCache,
}
85
86pub type LazySchemaCache = Arc<
92 std::sync::Mutex<std::collections::HashMap<(String, String), Option<HashMap<String, Value>>>>,
93>;
94
/// One compiled entry of a provider's upstream URL allowlist (produced by
/// `build_url_allowlist`, consumed by `matches_allowlist`).
#[derive(Debug, Clone)]
pub struct UpstreamAllowEntry {
    /// URL scheme; compared case-insensitively against the candidate URL.
    pub scheme: String,
    /// One glob per DNS label of the lower-cased host.
    ///
    /// A candidate host must have exactly this many labels, and each label must match the
    /// glob at the same position.
    pub host_label_patterns: Vec<glob::Pattern>,
    /// Exact path the candidate URL must have (`"/"` when the pattern's path was empty).
    pub path: String,
}
120
/// Outcome of validating a sandbox-supplied `X-Ati-Upstream-Url` header.
#[derive(Debug)]
enum UpstreamOverride {
    /// No header was sent; use the provider's normal upstream.
    None,
    /// The header URL passed the provider's allowlist; carries the URL as sent.
    Allow(String),
    /// The header must be refused with this HTTP status and human-readable reason.
    Reject(StatusCode, String),
}
136
137fn resolve_upstream_override(
148 state: &ProxyState,
149 provider: &crate::core::manifest::Provider,
150 header_value: Option<&str>,
151) -> UpstreamOverride {
152 let provider_accepts_override = provider.mcp_url_env.is_some();
153 match (header_value, provider_accepts_override) {
154 (None, _) => UpstreamOverride::None,
155 (Some(_), false) => UpstreamOverride::Reject(
156 StatusCode::BAD_REQUEST,
157 format!(
158 "X-Ati-Upstream-Url sent for provider '{}' which does not declare mcp_url_env",
159 provider.name
160 ),
161 ),
162 (Some(url), true) => {
163 let allowlist_key = format!("{}_allowed_urls", provider.name);
164 let entries = {
168 let mut cache = state.upstream_url_allowlists.lock().unwrap();
169 if !cache.contains_key(&provider.name) {
170 let compiled: Option<Vec<UpstreamAllowEntry>> = state
171 .keyring
172 .get(&allowlist_key)
173 .and_then(|csv| match build_url_allowlist(csv) {
174 Ok(set) => set,
175 Err(e) => {
176 tracing::warn!(
177 provider = %provider.name,
178 error = %e,
179 "failed to compile upstream URL allowlist; treating as missing"
180 );
181 None
182 }
183 });
184 cache.insert(provider.name.clone(), compiled);
185 }
186 cache.get(&provider.name).cloned().flatten()
187 };
188 let Some(entries) = entries else {
189 return UpstreamOverride::Reject(
190 StatusCode::FORBIDDEN,
191 format!(
192 "Provider '{}' has no upstream URL allowlist configured (set ATI_KEY_{}_ALLOWED_URLS on the proxy)",
193 provider.name,
194 provider.name.to_uppercase()
195 ),
196 );
197 };
198 let parsed = match url::Url::parse(url) {
203 Ok(u) => u,
204 Err(e) => {
205 return UpstreamOverride::Reject(
206 StatusCode::BAD_REQUEST,
207 format!("X-Ati-Upstream-Url '{url}' is not a valid URL: {e}"),
208 );
209 }
210 };
211 if matches_allowlist(&parsed, &entries) {
212 UpstreamOverride::Allow(url.to_string())
213 } else {
214 UpstreamOverride::Reject(
215 StatusCode::FORBIDDEN,
216 format!(
217 "Upstream URL '{url}' not in provider '{}'s allowlist",
218 provider.name
219 ),
220 )
221 }
222 }
223 }
224}
225
226fn extract_upstream_url(
237 state: &ProxyState,
238 provider: &Provider,
239 headers: &axum::http::HeaderMap,
240) -> Option<String> {
241 let header_value = headers
242 .get("x-ati-upstream-url")
243 .and_then(|v| v.to_str().ok())
244 .map(|s| s.trim())
245 .filter(|s| !s.is_empty());
246
247 match resolve_upstream_override(state, provider, header_value) {
248 UpstreamOverride::Allow(url) => Some(url),
249 UpstreamOverride::None => None,
250 UpstreamOverride::Reject(status, msg) => {
251 tracing::debug!(
252 provider = %provider.name,
253 status = status.as_u16(),
254 reason = %msg,
255 "lazy schema discovery: upstream URL rejected, falling back to static schema"
256 );
257 None
258 }
259 }
260}
261
262async fn lazy_fetch_schemas(
280 state: &ProxyState,
281 provider: &Provider,
282 upstream_url: &str,
283) -> Option<HashMap<String, Value>> {
284 let key = (provider.name.clone(), upstream_url.to_string());
285
286 {
288 let cache = state.lazy_schema_cache.lock().ok()?;
289 if let Some(cached) = cache.get(&key) {
290 return cached.clone();
291 }
292 }
293
294 let discovered: Option<HashMap<String, Value>> = match mcp_client::McpClient::connect_with_gen(
299 provider,
300 &state.keyring,
301 None,
302 Some(&state.auth_cache),
303 Some(upstream_url),
304 )
305 .await
306 {
307 Ok(client) => {
308 let result = client.list_tools().await;
309 client.disconnect().await;
310 match result {
311 Ok(tools) => {
312 let mut map = HashMap::new();
313 for t in tools {
314 if let Some(schema) = t.input_schema {
315 map.insert(t.name, schema);
316 }
317 }
318 Some(map)
319 }
320 Err(e) => {
321 tracing::warn!(
322 provider = %provider.name,
323 upstream = %upstream_url,
324 error = %e,
325 "lazy schema discovery: tools/list failed"
326 );
327 None
328 }
329 }
330 }
331 Err(e) => {
332 tracing::warn!(
333 provider = %provider.name,
334 upstream = %upstream_url,
335 error = %e,
336 "lazy schema discovery: MCP connect failed"
337 );
338 None
339 }
340 };
341
342 if let Ok(mut cache) = state.lazy_schema_cache.lock() {
347 cache.insert(key, discovered.clone());
348 }
349 discovered
350}
351
352fn matches_allowlist(url: &url::Url, entries: &[UpstreamAllowEntry]) -> bool {
366 if !url.username().is_empty() || url.password().is_some() {
367 return false;
368 }
369 let Some(host) = url.host_str() else {
370 return false;
371 };
372 let host_lower = host.to_ascii_lowercase();
373 let host_labels: Vec<&str> = host_lower.split('.').collect();
374 let path = url.path();
375 for entry in entries {
376 if !entry.scheme.eq_ignore_ascii_case(url.scheme()) {
377 continue;
378 }
379 if entry.host_label_patterns.len() != host_labels.len() {
384 continue;
385 }
386 if !entry
387 .host_label_patterns
388 .iter()
389 .zip(host_labels.iter())
390 .all(|(pat, label)| pat.matches(label))
391 {
392 continue;
393 }
394 if entry.path != path {
395 continue;
396 }
397 return true;
398 }
399 false
400}
401
402fn build_url_allowlist(csv: &str) -> Result<Option<Vec<UpstreamAllowEntry>>, String> {
432 let mut entries = Vec::new();
433 for raw in csv.split(',') {
434 let pat = raw.trim();
435 if pat.is_empty() {
436 continue;
437 }
438 let parsed = url::Url::parse(pat)
439 .map_err(|e| format!("upstream allowlist pattern '{pat}' is not a valid URL: {e}"))?;
440 if !parsed.username().is_empty() || parsed.password().is_some() {
441 return Err(format!(
442 "upstream allowlist pattern '{pat}' must not include userinfo"
443 ));
444 }
445 let host = parsed
446 .host_str()
447 .ok_or_else(|| format!("upstream allowlist pattern '{pat}' has no host"))?;
448 if host.contains("**") {
450 return Err(format!(
451 "upstream allowlist pattern '{pat}' must not contain '**' (use single '*' per DNS label)"
452 ));
453 }
454 let host_lower = host.to_ascii_lowercase();
455 let mut host_label_patterns = Vec::new();
456 for label in host_lower.split('.') {
457 if label.is_empty() {
458 return Err(format!(
459 "upstream allowlist pattern '{pat}' has empty DNS label"
460 ));
461 }
462 let p = glob::Pattern::new(label).map_err(|e| {
463 format!(
464 "upstream allowlist pattern '{pat}' has invalid host label glob '{label}': {e}"
465 )
466 })?;
467 host_label_patterns.push(p);
468 }
469 let path = if parsed.path().is_empty() {
470 "/".to_string()
471 } else {
472 parsed.path().to_string()
473 };
474 entries.push(UpstreamAllowEntry {
475 scheme: parsed.scheme().to_string(),
476 host_label_patterns,
477 path,
478 });
479 }
480 if entries.is_empty() {
481 return Ok(None);
482 }
483 Ok(Some(entries))
484}
485
/// JSON body of a `/call` request.
#[derive(Debug, Deserialize)]
pub struct CallRequest {
    /// Tool to invoke; `handle_call` also accepts `provider_tool` for `provider:tool`.
    pub tool_name: String,
    /// Tool arguments; defaults to an empty JSON object when omitted.
    #[serde(default = "default_args")]
    pub args: Value,
    /// Verbatim argv for CLI tools; when present it takes precedence over `args` in
    /// `args_as_positional`.
    #[serde(default)]
    pub raw_args: Option<Vec<String>>,
}
501
502fn default_args() -> Value {
503 Value::Object(serde_json::Map::new())
504}
505
506impl CallRequest {
507 fn args_as_map(&self) -> HashMap<String, Value> {
511 match &self.args {
512 Value::Object(map) => map.iter().map(|(k, v)| (k.clone(), v.clone())).collect(),
513 _ => HashMap::new(),
514 }
515 }
516
517 fn args_as_positional(&self) -> Vec<String> {
520 if let Some(ref raw) = self.raw_args {
522 return raw.clone();
523 }
524 match &self.args {
525 Value::Array(arr) => arr
527 .iter()
528 .map(|v| match v {
529 Value::String(s) => s.clone(),
530 other => other.to_string(),
531 })
532 .collect(),
533 Value::String(s) => s.split_whitespace().map(String::from).collect(),
535 Value::Object(map) => {
537 if let Some(Value::Array(pos)) = map.get("_positional") {
538 return pos
539 .iter()
540 .map(|v| match v {
541 Value::String(s) => s.clone(),
542 other => other.to_string(),
543 })
544 .collect();
545 }
546 let mut result = Vec::new();
548 for (k, v) in map {
549 result.push(format!("--{k}"));
550 match v {
551 Value::String(s) => result.push(s.clone()),
552 Value::Bool(true) => {} other => result.push(other.to_string()),
554 }
555 }
556 result
557 }
558 _ => Vec::new(),
559 }
560 }
561}
562
/// JSON body returned by `/call`.
#[derive(Debug, Serialize)]
pub struct CallResponse {
    /// Tool output; `null` on most error paths.
    pub result: Value,
    /// Error description; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
569
/// JSON body of a `/help` request.
#[derive(Debug, Deserialize)]
pub struct HelpRequest {
    /// Free-text question forwarded to the help LLM as the user message.
    pub query: String,
    /// Optional tool/scope name to focus the help prompt on.
    #[serde(default)]
    pub tool: Option<String>,
}
576
/// JSON body returned by `/help`.
#[derive(Debug, Serialize)]
pub struct HelpResponse {
    /// LLM answer text; empty on error.
    pub content: String,
    /// Error description; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
583
/// JSON body returned by the health handler.
#[derive(Debug, Serialize)]
pub struct HealthResponse {
    /// Always `"ok"` when the handler runs.
    pub status: String,
    /// Crate version (`CARGO_PKG_VERSION`).
    pub version: String,
    /// Number of public tools in the registry.
    pub tools: usize,
    /// Number of providers in the registry.
    pub providers: usize,
    /// Number of local skills.
    pub skills: usize,
    /// `"jwt"` when JWT auth is configured, otherwise `"disabled"`.
    pub auth: String,
}
593
/// Query-string filters for listing skills; every field is optional.
///
/// NOTE(review): the consuming handler is not shown here; field meanings are inferred from
/// their names.
#[derive(Debug, Deserialize)]
pub struct SkillsQuery {
    /// Presumably restricts results to one skill category.
    #[serde(default)]
    pub category: Option<String>,
    /// Presumably restricts results to skills tied to this provider.
    #[serde(default)]
    pub provider: Option<String>,
    /// Presumably restricts results to skills tied to this tool.
    #[serde(default)]
    pub tool: Option<String>,
    /// Presumably a free-text search term.
    #[serde(default)]
    pub search: Option<String>,
}
607
/// Query-string options for fetching a single skill; both flags are optional.
///
/// NOTE(review): the consuming handler is not shown here; meanings are inferred from names.
#[derive(Debug, Deserialize)]
pub struct SkillDetailQuery {
    /// Presumably requests skill metadata.
    #[serde(default)]
    pub meta: Option<bool>,
    /// Presumably requests the skill's reference files.
    #[serde(default)]
    pub refs: Option<bool>,
}
615
/// JSON body asking which skills apply to a set of scopes.
#[derive(Debug, Deserialize)]
pub struct SkillResolveRequest {
    /// Scope strings to resolve skills for.
    pub scopes: Vec<String>,
    /// Whether to include skill content in the answer; defaults to `false`.
    #[serde(default)]
    pub include_content: bool,
}
623
/// JSON body requesting bundles for several skills at once.
#[derive(Debug, Deserialize)]
pub struct SkillBundleBatchRequest {
    /// Names of the skills to bundle.
    pub names: Vec<String>,
}
628
/// Query-string options for listing the remote (SkillATI) skill catalog.
#[derive(Debug, Deserialize, Default)]
pub struct SkillAtiCatalogQuery {
    /// Optional search term; presumably filters catalog entries (handler not shown).
    #[serde(default)]
    pub search: Option<String>,
}
634
/// Query-string options for listing a remote skill's resources.
#[derive(Debug, Deserialize, Default)]
pub struct SkillAtiResourcesQuery {
    /// Optional path prefix; presumably restricts the listing (handler not shown).
    #[serde(default)]
    pub prefix: Option<String>,
}
640
/// Query-string selecting one file of a remote skill.
#[derive(Debug, Deserialize)]
pub struct SkillAtiFileQuery {
    /// Required file path within the skill.
    pub path: String,
}
645
/// Query-string filters for listing tools; every field is optional.
///
/// NOTE(review): the consuming handler is not shown here; meanings are inferred from names.
#[derive(Debug, Deserialize)]
pub struct ToolsQuery {
    /// Presumably restricts results to one provider.
    #[serde(default)]
    pub provider: Option<String>,
    /// Presumably a free-text search term.
    #[serde(default)]
    pub search: Option<String>,
}
655
656fn scopes_for_request(claims: Option<&TokenClaims>, state: &ProxyState) -> ScopeConfig {
659 match claims {
660 Some(claims) => ScopeConfig::from_jwt(claims),
661 None if state.jwt_config.is_none() => ScopeConfig::unrestricted(),
662 None => ScopeConfig {
663 scopes: Vec::new(),
664 sub: String::new(),
665 expires_at: 0,
666 rate_config: None,
667 },
668 }
669}
670
671fn visible_tools_for_scopes<'a>(
672 state: &'a ProxyState,
673 scopes: &ScopeConfig,
674) -> Vec<(&'a Provider, &'a Tool)> {
675 crate::core::scope::filter_tools_by_scope(state.registry.list_public_tools(), scopes)
676}
677
678fn visible_skill_names(
679 state: &ProxyState,
680 scopes: &ScopeConfig,
681) -> std::collections::HashSet<String> {
682 skill::visible_skills(&state.skill_registry, &state.registry, scopes)
683 .into_iter()
684 .map(|skill| skill.name.clone())
685 .collect()
686}
687
688fn visible_remote_skill_names(
700 state: &ProxyState,
701 scopes: &ScopeConfig,
702 catalog: &[RemoteSkillMeta],
703) -> std::collections::HashSet<String> {
704 let mut visible: std::collections::HashSet<String> = std::collections::HashSet::new();
705 if catalog.is_empty() {
706 return visible;
707 }
708 if scopes.is_wildcard() {
709 for entry in catalog {
710 visible.insert(entry.name.clone());
711 }
712 return visible;
713 }
714
715 let allowed_tool_pairs: Vec<(String, String)> =
719 crate::core::scope::filter_tools_by_scope(state.registry.list_public_tools(), scopes)
720 .into_iter()
721 .map(|(p, t)| (p.name.clone(), t.name.clone()))
722 .collect();
723 let allowed_tool_names: std::collections::HashSet<&str> =
724 allowed_tool_pairs.iter().map(|(_, t)| t.as_str()).collect();
725 let allowed_provider_names: std::collections::HashSet<&str> =
726 allowed_tool_pairs.iter().map(|(p, _)| p.as_str()).collect();
727 let allowed_categories: std::collections::HashSet<String> = state
728 .registry
729 .list_providers()
730 .into_iter()
731 .filter(|p| allowed_provider_names.contains(p.name.as_str()))
732 .filter_map(|p| p.category.clone())
733 .collect();
734
735 for scope in &scopes.scopes {
737 if let Some(skill_name) = scope.strip_prefix("skill:") {
738 if catalog.iter().any(|e| e.name == skill_name) {
739 visible.insert(skill_name.to_string());
740 }
741 }
742 }
743
744 for entry in catalog {
748 if entry
749 .tools
750 .iter()
751 .any(|t| allowed_tool_names.contains(t.as_str()))
752 || entry
753 .providers
754 .iter()
755 .any(|p| allowed_provider_names.contains(p.as_str()))
756 || entry
757 .categories
758 .iter()
759 .any(|c| allowed_categories.contains(c))
760 {
761 visible.insert(entry.name.clone());
762 }
763 }
764
765 visible
766}
767
768async fn visible_skill_names_with_remote(
772 state: &ProxyState,
773 scopes: &ScopeConfig,
774 client: &SkillAtiClient,
775) -> Result<std::collections::HashSet<String>, SkillAtiError> {
776 let mut names = visible_skill_names(state, scopes);
777 let catalog = client.catalog().await?;
778 let remote = visible_remote_skill_names(state, scopes, &catalog);
779 names.extend(remote);
780 Ok(names)
781}
782
783async fn handle_call(
784 State(state): State<Arc<ProxyState>>,
785 req: HttpRequest<Body>,
786) -> impl IntoResponse {
787 let claims = req.extensions().get::<TokenClaims>().cloned();
789 if let Some(ref c) = claims {
796 sentry_scope::set_jwt_sentry_scope(c);
797 }
798 let bearer_token: String = req
802 .extensions()
803 .get::<BearerToken>()
804 .map(|b| b.0.clone())
805 .unwrap_or_default();
806
807 let upstream_url_header: Option<String> = req
812 .headers()
813 .get("x-ati-upstream-url")
814 .and_then(|v| v.to_str().ok())
815 .map(|s| s.trim().to_string())
816 .filter(|s| !s.is_empty());
817
818 let body_bytes = match axum::body::to_bytes(req.into_body(), max_call_body_bytes()).await {
825 Ok(b) => b,
826 Err(e) => {
827 return (
828 StatusCode::BAD_REQUEST,
829 Json(CallResponse {
830 result: Value::Null,
831 error: Some(format!("Failed to read request body: {e}")),
832 }),
833 );
834 }
835 };
836
837 let call_req: CallRequest = match serde_json::from_slice(&body_bytes) {
838 Ok(r) => r,
839 Err(e) => {
840 return (
841 StatusCode::UNPROCESSABLE_ENTITY,
842 Json(CallResponse {
843 result: Value::Null,
844 error: Some(format!("Invalid request: {e}")),
845 }),
846 );
847 }
848 };
849
850 tracing::debug!(
851 tool = %call_req.tool_name,
852 args = ?call_req.args,
853 "POST /call"
854 );
855
856 let (provider, tool) = match state.registry.get_tool(&call_req.tool_name) {
859 Some(pt) => pt,
860 None => {
861 let mut resolved = None;
865 for (idx, _) in call_req.tool_name.match_indices('_') {
866 let candidate = format!(
867 "{}:{}",
868 &call_req.tool_name[..idx],
869 &call_req.tool_name[idx + 1..]
870 );
871 if let Some(pt) = state.registry.get_tool(&candidate) {
872 tracing::debug!(
873 original = %call_req.tool_name,
874 resolved = %candidate,
875 "resolved underscore tool name to colon format"
876 );
877 resolved = Some(pt);
878 break;
879 }
880 }
881
882 match resolved {
883 Some(pt) => pt,
884 None => {
885 return (
886 StatusCode::NOT_FOUND,
887 Json(CallResponse {
888 result: Value::Null,
889 error: Some(format!("Unknown tool: '{}'", call_req.tool_name)),
890 }),
891 );
892 }
893 }
894 }
895 };
896
897 if let Some(tool_scope) = &tool.scope {
899 let scopes = match &claims {
900 Some(c) => ScopeConfig::from_jwt(c),
901 None if state.jwt_config.is_none() => ScopeConfig::unrestricted(), None => {
903 return (
904 StatusCode::FORBIDDEN,
905 Json(CallResponse {
906 result: Value::Null,
907 error: Some("Authentication required — no JWT provided".into()),
908 }),
909 );
910 }
911 };
912
913 if !scopes.is_allowed(tool_scope) {
914 return (
915 StatusCode::FORBIDDEN,
916 Json(CallResponse {
917 result: Value::Null,
918 error: Some(format!(
919 "Access denied: '{}' is not in your scopes",
920 tool.name
921 )),
922 }),
923 );
924 }
925 }
926
927 {
929 let scopes = match &claims {
930 Some(c) => ScopeConfig::from_jwt(c),
931 None => ScopeConfig::unrestricted(),
932 };
933 if let Some(ref rate_config) = scopes.rate_config {
934 if let Err(e) = crate::core::rate::check_and_record(&call_req.tool_name, rate_config) {
935 return (
936 StatusCode::TOO_MANY_REQUESTS,
937 Json(CallResponse {
938 result: Value::Null,
939 error: Some(format!("{e}")),
940 }),
941 );
942 }
943 }
944 }
945
946 let gen_ctx = GenContext {
948 jwt_sub: claims
949 .as_ref()
950 .map(|c| c.sub.clone())
951 .unwrap_or_else(|| "dev".into()),
952 jwt_scope: claims
953 .as_ref()
954 .map(|c| c.scope.clone())
955 .unwrap_or_else(|| "*".into()),
956 tool_name: call_req.tool_name.clone(),
957 timestamp: crate::core::jwt::now_secs(),
958 jwt_token: bearer_token.clone(),
959 };
960
961 let override_mcp_url: Option<String> =
965 match resolve_upstream_override(&state, provider, upstream_url_header.as_deref()) {
966 UpstreamOverride::None => None,
967 UpstreamOverride::Allow(url) => Some(url),
968 UpstreamOverride::Reject(status, msg) => {
969 tracing::warn!(
970 provider = %provider.name,
971 tool = %call_req.tool_name,
972 status = status.as_u16(),
973 reason = %msg,
974 "rejecting sandbox-supplied upstream URL"
975 );
976 return (
977 status,
978 Json(CallResponse {
979 result: Value::Null,
980 error: Some(msg),
981 }),
982 );
983 }
984 };
985
986 let agent_sub = claims.as_ref().map(|c| c.sub.clone()).unwrap_or_default();
988 let job_id = claims
989 .as_ref()
990 .and_then(|c| c.job_id.clone())
991 .unwrap_or_default();
992 let sandbox_id = claims
993 .as_ref()
994 .and_then(|c| c.sandbox_id.clone())
995 .unwrap_or_default();
996 tracing::info!(
997 tool = %call_req.tool_name,
998 agent = %agent_sub,
999 job_id = %job_id,
1000 sandbox_id = %sandbox_id,
1001 "tool call"
1002 );
1003 let start = std::time::Instant::now();
1004
1005 let response = match provider.handler.as_str() {
1006 "mcp" => {
1007 let args_map = call_req.args_as_map();
1008 match mcp_client::execute_with_gen(
1009 provider,
1010 &call_req.tool_name,
1011 &args_map,
1012 &state.keyring,
1013 Some(&gen_ctx),
1014 Some(&state.auth_cache),
1015 override_mcp_url.as_deref(),
1016 )
1017 .await
1018 {
1019 Ok(result) => (
1020 StatusCode::OK,
1021 Json(CallResponse {
1022 result,
1023 error: None,
1024 }),
1025 ),
1026 Err(e) => {
1027 let (provider_name, operation_id) =
1035 sentry_scope::provider_and_op(&provider.name, &call_req.tool_name);
1036 let msg = e.to_string();
1037 sentry_scope::report_upstream_error(
1038 &provider_name,
1039 &operation_id,
1040 0,
1041 502,
1042 None,
1043 Some(&msg),
1044 );
1045 sentry_scope::capture_error_with_scope(
1046 &e,
1047 &provider_name,
1048 &operation_id,
1049 0,
1050 502,
1051 None,
1052 Some(&msg),
1053 );
1054 (
1055 StatusCode::BAD_GATEWAY,
1056 Json(CallResponse {
1057 result: Value::Null,
1058 error: Some(format!("MCP error: {e}")),
1059 }),
1060 )
1061 }
1062 }
1063 }
1064 "cli" => {
1065 let positional = call_req.args_as_positional();
1066 match crate::core::cli_executor::execute_with_gen(
1067 provider,
1068 &positional,
1069 &state.keyring,
1070 Some(&gen_ctx),
1071 Some(&state.auth_cache),
1072 )
1073 .await
1074 {
1075 Ok(result) => (
1076 StatusCode::OK,
1077 Json(CallResponse {
1078 result,
1079 error: None,
1080 }),
1081 ),
1082 Err(e) => {
1083 let (provider_name, operation_id) =
1088 sentry_scope::provider_and_op(&provider.name, &call_req.tool_name);
1089 let msg = e.to_string();
1090 sentry_scope::report_upstream_error(
1091 &provider_name,
1092 &operation_id,
1093 0,
1094 502,
1095 None,
1096 Some(&msg),
1097 );
1098 sentry_scope::capture_error_with_scope(
1099 &e,
1100 &provider_name,
1101 &operation_id,
1102 0,
1103 502,
1104 None,
1105 Some(&msg),
1106 );
1107 (
1108 StatusCode::BAD_GATEWAY,
1109 Json(CallResponse {
1110 result: Value::Null,
1111 error: Some(format!("CLI error: {e}")),
1112 }),
1113 )
1114 }
1115 }
1116 }
1117 "file_manager" => {
1118 let args_map = call_req.args_as_map();
1119 match dispatch_file_manager(&call_req.tool_name, &args_map, provider, &state.keyring)
1120 .await
1121 {
1122 Ok(result) => (
1123 StatusCode::OK,
1124 Json(CallResponse {
1125 result,
1126 error: None,
1127 }),
1128 ),
1129 Err((status, msg)) => (
1130 status,
1131 Json(CallResponse {
1132 result: Value::Null,
1133 error: Some(msg),
1134 }),
1135 ),
1136 }
1137 }
1138 _ => {
1139 let args_map = call_req.args_as_map();
1140 let raw_response = match http::execute_tool_with_gen(
1141 provider,
1142 tool,
1143 &args_map,
1144 &state.keyring,
1145 Some(&gen_ctx),
1146 Some(&state.auth_cache),
1147 )
1148 .await
1149 {
1150 Ok(resp) => resp,
1151 Err(http::HttpError::NoRecordsFound { status }) => {
1152 let duration = start.elapsed();
1156 tracing::info!(
1157 tool = %call_req.tool_name,
1158 upstream_status = status,
1159 "upstream returned no records"
1160 );
1161 write_proxy_audit(&call_req, &agent_sub, claims.as_ref(), duration, None);
1162 return (
1163 StatusCode::OK,
1164 Json(CallResponse {
1165 result: serde_json::json!({ "records": [] }),
1166 error: None,
1167 }),
1168 );
1169 }
1170 Err(e) => {
1171 let duration = start.elapsed();
1172 let (provider_name, operation_id) =
1178 sentry_scope::provider_and_op(&provider.name, &call_req.tool_name);
1179 let (upstream_status, error_type, error_message) = match &e {
1180 http::HttpError::ApiError {
1181 status,
1182 error_type,
1183 error_message,
1184 ..
1185 } => (*status, error_type.clone(), error_message.clone()),
1186 _ => (0u16, None, Some(e.to_string())),
1187 };
1188 sentry_scope::report_upstream_error(
1189 &provider_name,
1190 &operation_id,
1191 upstream_status,
1192 502,
1193 error_type.as_deref(),
1194 error_message.as_deref(),
1195 );
1196 sentry_scope::capture_error_with_scope(
1202 &e,
1203 &provider_name,
1204 &operation_id,
1205 upstream_status,
1206 502,
1207 error_type.as_deref(),
1208 error_message.as_deref(),
1209 );
1210 write_proxy_audit(
1211 &call_req,
1212 &agent_sub,
1213 claims.as_ref(),
1214 duration,
1215 Some(&e.to_string()),
1216 );
1217 return (
1218 StatusCode::BAD_GATEWAY,
1219 Json(CallResponse {
1220 result: Value::Null,
1221 error: Some(format!("Upstream API error: {e}")),
1222 }),
1223 );
1224 }
1225 };
1226
1227 let processed = match response::process_response(&raw_response, tool.response.as_ref())
1228 {
1229 Ok(p) => p,
1230 Err(e) => {
1231 let duration = start.elapsed();
1232 write_proxy_audit(
1233 &call_req,
1234 &agent_sub,
1235 claims.as_ref(),
1236 duration,
1237 Some(&e.to_string()),
1238 );
1239 return (
1240 StatusCode::INTERNAL_SERVER_ERROR,
1241 Json(CallResponse {
1242 result: raw_response,
1243 error: Some(format!("Response processing error: {e}")),
1244 }),
1245 );
1246 }
1247 };
1248
1249 (
1250 StatusCode::OK,
1251 Json(CallResponse {
1252 result: processed,
1253 error: None,
1254 }),
1255 )
1256 }
1257 };
1258
1259 let duration = start.elapsed();
1260 let error_msg = response.1.error.as_deref();
1261 write_proxy_audit(&call_req, &agent_sub, claims.as_ref(), duration, error_msg);
1262
1263 response
1264}
1265
1266async fn handle_help(
1267 State(state): State<Arc<ProxyState>>,
1268 claims: Option<Extension<TokenClaims>>,
1269 Json(req): Json<HelpRequest>,
1270) -> impl IntoResponse {
1271 tracing::debug!(query = %req.query, tool = ?req.tool, "POST /help");
1272
1273 let claims = claims.map(|Extension(claims)| claims);
1274 let scopes = scopes_for_request(claims.as_ref(), &state);
1275
1276 let (llm_provider, llm_tool) = match state.registry.get_tool("_chat_completion") {
1277 Some(pt) => pt,
1278 None => {
1279 return (
1280 StatusCode::SERVICE_UNAVAILABLE,
1281 Json(HelpResponse {
1282 content: String::new(),
1283 error: Some("No _llm.toml manifest found. Proxy help requires a configured LLM provider.".into()),
1284 }),
1285 );
1286 }
1287 };
1288
1289 let api_key = match llm_provider
1290 .auth_key_name
1291 .as_deref()
1292 .and_then(|k| state.keyring.get(k))
1293 {
1294 Some(key) => key.to_string(),
1295 None => {
1296 return (
1297 StatusCode::SERVICE_UNAVAILABLE,
1298 Json(HelpResponse {
1299 content: String::new(),
1300 error: Some("LLM API key not found in keyring".into()),
1301 }),
1302 );
1303 }
1304 };
1305
1306 let resolved_skills = skill::resolve_skills(&state.skill_registry, &state.registry, &scopes);
1307 let local_skills_section = if resolved_skills.is_empty() {
1308 String::new()
1309 } else {
1310 format!(
1311 "## Available Skills (methodology guides)\n{}",
1312 skill::build_skill_context(&resolved_skills)
1313 )
1314 };
1315 let remote_query = req
1316 .tool
1317 .as_ref()
1318 .map(|tool| format!("{tool} {}", req.query))
1319 .unwrap_or_else(|| req.query.clone());
1320 let remote_skills_section =
1321 build_remote_skillati_section(&state.keyring, &remote_query, 12).await;
1322 let skills_section = merge_help_skill_sections(&[local_skills_section, remote_skills_section]);
1323
1324 let visible_tools = visible_tools_for_scopes(&state, &scopes);
1326 let system_prompt = if let Some(ref tool_name) = req.tool {
1327 match build_scoped_prompt(tool_name, &visible_tools, &skills_section) {
1329 Some(prompt) => prompt,
1330 None => {
1331 return (
1332 StatusCode::FORBIDDEN,
1333 Json(HelpResponse {
1334 content: String::new(),
1335 error: Some(format!(
1336 "Scope '{tool_name}' is not visible in your current scopes."
1337 )),
1338 }),
1339 );
1340 }
1341 }
1342 } else {
1343 let tools_context = build_tool_context(&visible_tools);
1344 HELP_SYSTEM_PROMPT
1345 .replace("{tools}", &tools_context)
1346 .replace("{skills_section}", &skills_section)
1347 };
1348
1349 let request_body = serde_json::json!({
1350 "model": "zai-glm-4.7",
1351 "messages": [
1352 {"role": "system", "content": system_prompt},
1353 {"role": "user", "content": req.query}
1354 ],
1355 "max_completion_tokens": 1536,
1356 "temperature": 0.3
1357 });
1358
1359 let client = reqwest::Client::new();
1360 let url = format!(
1361 "{}{}",
1362 llm_provider.base_url.trim_end_matches('/'),
1363 llm_tool.endpoint
1364 );
1365
1366 let response = match client
1367 .post(&url)
1368 .bearer_auth(&api_key)
1369 .json(&request_body)
1370 .send()
1371 .await
1372 {
1373 Ok(r) => r,
1374 Err(e) => {
1375 return (
1376 StatusCode::BAD_GATEWAY,
1377 Json(HelpResponse {
1378 content: String::new(),
1379 error: Some(format!("LLM request failed: {e}")),
1380 }),
1381 );
1382 }
1383 };
1384
1385 if !response.status().is_success() {
1386 let status = response.status();
1387 let body = response.text().await.unwrap_or_default();
1388 return (
1389 StatusCode::BAD_GATEWAY,
1390 Json(HelpResponse {
1391 content: String::new(),
1392 error: Some(format!("LLM API error ({status}): {body}")),
1393 }),
1394 );
1395 }
1396
1397 let body: Value = match response.json().await {
1398 Ok(b) => b,
1399 Err(e) => {
1400 return (
1401 StatusCode::INTERNAL_SERVER_ERROR,
1402 Json(HelpResponse {
1403 content: String::new(),
1404 error: Some(format!("Failed to parse LLM response: {e}")),
1405 }),
1406 );
1407 }
1408 };
1409
1410 let content = body
1411 .pointer("/choices/0/message/content")
1412 .and_then(|c| c.as_str())
1413 .unwrap_or("No response from LLM")
1414 .to_string();
1415
1416 (
1417 StatusCode::OK,
1418 Json(HelpResponse {
1419 content,
1420 error: None,
1421 }),
1422 )
1423}
1424
1425async fn handle_health(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
1426 let auth = if state.jwt_config.is_some() {
1427 "jwt"
1428 } else {
1429 "disabled"
1430 };
1431
1432 Json(HealthResponse {
1433 status: "ok".into(),
1434 version: env!("CARGO_PKG_VERSION").into(),
1435 tools: state.registry.list_public_tools().len(),
1436 providers: state.registry.list_providers().len(),
1437 skills: state.skill_registry.skill_count(),
1438 auth: auth.into(),
1439 })
1440}
1441
1442async fn handle_jwks(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
1444 match &state.jwks_json {
1445 Some(jwks) => (StatusCode::OK, Json(jwks.clone())),
1446 None => (
1447 StatusCode::NOT_FOUND,
1448 Json(serde_json::json!({"error": "JWKS not configured"})),
1449 ),
1450 }
1451}
1452
1453async fn handle_mcp(
1458 State(state): State<Arc<ProxyState>>,
1459 claims: Option<Extension<TokenClaims>>,
1460 bearer: Option<Extension<BearerToken>>,
1461 headers: axum::http::HeaderMap,
1462 Json(msg): Json<Value>,
1463) -> impl IntoResponse {
1464 let claims = claims.map(|Extension(claims)| claims);
1465 let bearer_token: String = bearer.map(|Extension(b)| b.0).unwrap_or_default();
1469 let upstream_url_header: Option<String> = headers
1472 .get("x-ati-upstream-url")
1473 .and_then(|v| v.to_str().ok())
1474 .map(|s| s.trim().to_string())
1475 .filter(|s| !s.is_empty());
1476 let scopes = scopes_for_request(claims.as_ref(), &state);
1477 let method = msg.get("method").and_then(|m| m.as_str()).unwrap_or("");
1478 let id = msg.get("id").cloned();
1479 tracing::info!(
1480 %method,
1481 agent = claims.as_ref().map(|c| c.sub.as_str()).unwrap_or(""),
1482 job_id = claims.as_ref().and_then(|c| c.job_id.as_deref()).unwrap_or(""),
1483 sandbox_id = claims.as_ref().and_then(|c| c.sandbox_id.as_deref()).unwrap_or(""),
1484 "mcp call"
1485 );
1486
1487 match method {
1488 "initialize" => {
1489 let result = serde_json::json!({
1490 "protocolVersion": "2025-03-26",
1491 "capabilities": {
1492 "tools": { "listChanged": false }
1493 },
1494 "serverInfo": {
1495 "name": "ati-proxy",
1496 "version": env!("CARGO_PKG_VERSION")
1497 }
1498 });
1499 jsonrpc_success(id, result)
1500 }
1501
1502 "notifications/initialized" => (StatusCode::ACCEPTED, Json(Value::Null)),
1503
1504 "tools/list" => {
1505 let visible_tools = visible_tools_for_scopes(&state, &scopes);
1506
1507 let mut schemas_by_provider: HashMap<String, HashMap<String, Value>> = HashMap::new();
1516 for (provider, tool) in visible_tools.iter() {
1517 if tool.input_schema.is_some() {
1518 continue;
1519 }
1520 if provider.handler != "mcp" || provider.mcp_url_env.is_none() {
1521 continue;
1522 }
1523 if schemas_by_provider.contains_key(&provider.name) {
1524 continue;
1525 }
1526 let map = match extract_upstream_url(&state, provider, &headers) {
1527 Some(url) => lazy_fetch_schemas(&state, provider, &url)
1528 .await
1529 .unwrap_or_default(),
1530 None => HashMap::new(),
1531 };
1532 schemas_by_provider.insert(provider.name.clone(), map);
1533 }
1534
1535 let mcp_tools: Vec<Value> = visible_tools
1536 .iter()
1537 .map(|(provider, tool)| {
1538 let schema = tool.input_schema.clone().or_else(|| {
1539 schemas_by_provider
1540 .get(&provider.name)
1541 .and_then(|m| m.get(&tool.name).cloned())
1542 });
1543 serde_json::json!({
1544 "name": tool.name,
1545 "description": tool.description,
1546 "inputSchema": schema.unwrap_or(serde_json::json!({
1547 "type": "object",
1548 "properties": {}
1549 }))
1550 })
1551 })
1552 .collect();
1553
1554 let result = serde_json::json!({
1555 "tools": mcp_tools,
1556 });
1557 jsonrpc_success(id, result)
1558 }
1559
1560 "tools/call" => {
1561 let params = msg.get("params").cloned().unwrap_or(Value::Null);
1562 let tool_name = params.get("name").and_then(|n| n.as_str()).unwrap_or("");
1563 let arguments: HashMap<String, Value> = params
1564 .get("arguments")
1565 .and_then(|a| serde_json::from_value(a.clone()).ok())
1566 .unwrap_or_default();
1567
1568 if tool_name.is_empty() {
1569 return jsonrpc_error(id, -32602, "Missing tool name in params.name");
1570 }
1571
1572 let (provider, _tool) = match state.registry.get_tool(tool_name) {
1573 Some(pt) => pt,
1574 None => {
1575 return jsonrpc_error(id, -32602, &format!("Unknown tool: '{tool_name}'"));
1576 }
1577 };
1578
1579 if let Some(tool_scope) = &_tool.scope {
1580 if !scopes.is_allowed(tool_scope) {
1581 return jsonrpc_error(
1582 id,
1583 -32001,
1584 &format!("Access denied: '{}' is not in your scopes", _tool.name),
1585 );
1586 }
1587 }
1588
1589 tracing::debug!(%tool_name, provider = %provider.name, "MCP tools/call");
1590
1591 let override_mcp_url: Option<String> =
1597 match resolve_upstream_override(&state, provider, upstream_url_header.as_deref()) {
1598 UpstreamOverride::None => None,
1599 UpstreamOverride::Allow(url) => Some(url),
1600 UpstreamOverride::Reject(status, msg) => {
1601 let code = if status == StatusCode::BAD_REQUEST {
1602 -32602
1603 } else {
1604 -32001
1605 };
1606 tracing::warn!(
1607 provider = %provider.name,
1608 tool = %tool_name,
1609 status = status.as_u16(),
1610 reason = %msg,
1611 "rejecting sandbox-supplied upstream URL on /mcp"
1612 );
1613 return jsonrpc_error(id, code, &msg);
1614 }
1615 };
1616
1617 let mcp_gen_ctx = GenContext {
1618 jwt_sub: claims
1619 .as_ref()
1620 .map(|claims| claims.sub.clone())
1621 .unwrap_or_else(|| "dev".into()),
1622 jwt_scope: claims
1623 .as_ref()
1624 .map(|claims| claims.scope.clone())
1625 .unwrap_or_else(|| "*".into()),
1626 tool_name: tool_name.to_string(),
1627 timestamp: crate::core::jwt::now_secs(),
1628 jwt_token: bearer_token.clone(),
1629 };
1630
1631 let result = if provider.is_mcp() {
1632 mcp_client::execute_with_gen(
1633 provider,
1634 tool_name,
1635 &arguments,
1636 &state.keyring,
1637 Some(&mcp_gen_ctx),
1638 Some(&state.auth_cache),
1639 override_mcp_url.as_deref(),
1640 )
1641 .await
1642 } else if provider.is_cli() {
1643 let raw: Vec<String> = arguments
1645 .iter()
1646 .flat_map(|(k, v)| {
1647 let val = match v {
1648 Value::String(s) => s.clone(),
1649 other => other.to_string(),
1650 };
1651 vec![format!("--{k}"), val]
1652 })
1653 .collect();
1654 crate::core::cli_executor::execute_with_gen(
1655 provider,
1656 &raw,
1657 &state.keyring,
1658 Some(&mcp_gen_ctx),
1659 Some(&state.auth_cache),
1660 )
1661 .await
1662 .map_err(|e| mcp_client::McpError::Transport(e.to_string()))
1663 } else {
1664 match http::execute_tool_with_gen(
1665 provider,
1666 _tool,
1667 &arguments,
1668 &state.keyring,
1669 Some(&mcp_gen_ctx),
1670 Some(&state.auth_cache),
1671 )
1672 .await
1673 {
1674 Ok(val) => Ok(val),
1675 Err(e) => Err(mcp_client::McpError::Transport(e.to_string())),
1676 }
1677 };
1678
1679 match result {
1680 Ok(value) => {
1681 let text = match &value {
1682 Value::String(s) => s.clone(),
1683 other => serde_json::to_string_pretty(other).unwrap_or_default(),
1684 };
1685 let mcp_result = serde_json::json!({
1686 "content": [{"type": "text", "text": text}],
1687 "isError": false,
1688 });
1689 jsonrpc_success(id, mcp_result)
1690 }
1691 Err(e) => {
1692 let mcp_result = serde_json::json!({
1693 "content": [{"type": "text", "text": format!("Error: {e}")}],
1694 "isError": true,
1695 });
1696 jsonrpc_success(id, mcp_result)
1697 }
1698 }
1699 }
1700
1701 _ => jsonrpc_error(id, -32601, &format!("Method not found: '{method}'")),
1702 }
1703}
1704
1705fn jsonrpc_success(id: Option<Value>, result: Value) -> (StatusCode, Json<Value>) {
1706 (
1707 StatusCode::OK,
1708 Json(serde_json::json!({
1709 "jsonrpc": "2.0",
1710 "id": id,
1711 "result": result,
1712 })),
1713 )
1714}
1715
1716fn jsonrpc_error(id: Option<Value>, code: i64, message: &str) -> (StatusCode, Json<Value>) {
1717 (
1718 StatusCode::OK,
1719 Json(serde_json::json!({
1720 "jsonrpc": "2.0",
1721 "id": id,
1722 "error": {
1723 "code": code,
1724 "message": message,
1725 }
1726 })),
1727 )
1728}
1729
1730async fn handle_tools_list(
1736 State(state): State<Arc<ProxyState>>,
1737 claims: Option<Extension<TokenClaims>>,
1738 headers: axum::http::HeaderMap,
1739 axum::extract::Query(query): axum::extract::Query<ToolsQuery>,
1740) -> impl IntoResponse {
1741 tracing::debug!(
1742 provider = ?query.provider,
1743 search = ?query.search,
1744 "GET /tools"
1745 );
1746
1747 let claims = claims.map(|Extension(claims)| claims);
1748 let scopes = scopes_for_request(claims.as_ref(), &state);
1749 let all_tools = visible_tools_for_scopes(&state, &scopes);
1750
1751 let filtered: Vec<&(&Provider, &Tool)> = all_tools
1752 .iter()
1753 .filter(|(provider, tool)| {
1754 if let Some(ref p) = query.provider {
1755 if provider.name != *p {
1756 return false;
1757 }
1758 }
1759 if let Some(ref q) = query.search {
1760 let q = q.to_lowercase();
1761 let name_match = tool.name.to_lowercase().contains(&q);
1762 let desc_match = tool.description.to_lowercase().contains(&q);
1763 let tag_match = tool.tags.iter().any(|t| t.to_lowercase().contains(&q));
1764 if !name_match && !desc_match && !tag_match {
1765 return false;
1766 }
1767 }
1768 true
1769 })
1770 .collect();
1771
1772 let mut schemas_by_provider: HashMap<String, HashMap<String, Value>> = HashMap::new();
1786 for (provider, tool) in filtered.iter().copied() {
1787 if tool.input_schema.is_some() {
1788 continue;
1789 }
1790 if provider.handler != "mcp" || provider.mcp_url_env.is_none() {
1791 continue;
1792 }
1793 if schemas_by_provider.contains_key(&provider.name) {
1794 continue;
1795 }
1796 let map = match extract_upstream_url(&state, provider, &headers) {
1797 Some(upstream_url) => lazy_fetch_schemas(&state, provider, &upstream_url)
1798 .await
1799 .unwrap_or_default(),
1800 None => HashMap::new(),
1801 };
1802 schemas_by_provider.insert(provider.name.clone(), map);
1803 }
1804
1805 let tools: Vec<Value> = filtered
1806 .into_iter()
1807 .map(|(provider, tool)| {
1808 let input_schema = tool.input_schema.clone().or_else(|| {
1809 schemas_by_provider
1810 .get(&provider.name)
1811 .and_then(|m| m.get(&tool.name).cloned())
1812 });
1813 serde_json::json!({
1814 "name": tool.name,
1815 "description": tool.description,
1816 "provider": provider.name,
1817 "method": format!("{:?}", tool.method),
1818 "tags": tool.tags,
1819 "skills": provider.skills,
1820 "input_schema": input_schema,
1821 })
1822 })
1823 .collect();
1824
1825 (StatusCode::OK, Json(Value::Array(tools)))
1826}
1827
1828async fn handle_tool_info(
1830 State(state): State<Arc<ProxyState>>,
1831 claims: Option<Extension<TokenClaims>>,
1832 headers: axum::http::HeaderMap,
1833 axum::extract::Path(name): axum::extract::Path<String>,
1834) -> impl IntoResponse {
1835 tracing::debug!(tool = %name, "GET /tools/:name");
1836
1837 let claims = claims.map(|Extension(claims)| claims);
1838 let scopes = scopes_for_request(claims.as_ref(), &state);
1839
1840 match state
1841 .registry
1842 .get_tool(&name)
1843 .filter(|(_, tool)| match &tool.scope {
1844 Some(scope) => scopes.is_allowed(scope),
1845 None => true,
1846 }) {
1847 Some((provider, tool)) => {
1848 let mut skills: Vec<String> = provider.skills.clone();
1850 for s in state.skill_registry.skills_for_tool(&tool.name) {
1851 if !skills.contains(&s.name) {
1852 skills.push(s.name.clone());
1853 }
1854 }
1855 for s in state.skill_registry.skills_for_provider(&provider.name) {
1856 if !skills.contains(&s.name) {
1857 skills.push(s.name.clone());
1858 }
1859 }
1860
1861 let input_schema = resolve_lazy_input_schema(&state, provider, tool, &headers).await;
1867
1868 (
1869 StatusCode::OK,
1870 Json(serde_json::json!({
1871 "name": tool.name,
1872 "description": tool.description,
1873 "provider": provider.name,
1874 "method": format!("{:?}", tool.method),
1875 "endpoint": tool.endpoint,
1876 "tags": tool.tags,
1877 "hint": tool.hint,
1878 "skills": skills,
1879 "input_schema": input_schema,
1880 "scope": tool.scope,
1881 })),
1882 )
1883 }
1884 None => (
1885 StatusCode::NOT_FOUND,
1886 Json(serde_json::json!({"error": format!("Tool '{name}' not found")})),
1887 ),
1888 }
1889}
1890
1891async fn resolve_lazy_input_schema(
1897 state: &ProxyState,
1898 provider: &Provider,
1899 tool: &Tool,
1900 headers: &axum::http::HeaderMap,
1901) -> Option<Value> {
1902 if let Some(ref schema) = tool.input_schema {
1909 return Some(schema.clone());
1910 }
1911 if provider.handler != "mcp" || provider.mcp_url_env.is_none() {
1913 return None;
1914 }
1915 let upstream_url = extract_upstream_url(state, provider, headers)?;
1916 let schemas = lazy_fetch_schemas(state, provider, &upstream_url).await?;
1917 schemas.get(&tool.name).cloned()
1918}
1919
1920async fn handle_skills_list(
1925 State(state): State<Arc<ProxyState>>,
1926 claims: Option<Extension<TokenClaims>>,
1927 axum::extract::Query(query): axum::extract::Query<SkillsQuery>,
1928) -> impl IntoResponse {
1929 tracing::debug!(
1930 category = ?query.category,
1931 provider = ?query.provider,
1932 tool = ?query.tool,
1933 search = ?query.search,
1934 "GET /skills"
1935 );
1936
1937 let claims = claims.map(|Extension(claims)| claims);
1938 let scopes = scopes_for_request(claims.as_ref(), &state);
1939 let visible_names = visible_skill_names(&state, &scopes);
1940
1941 let skills: Vec<&skill::SkillMeta> = if let Some(search_query) = &query.search {
1942 state
1943 .skill_registry
1944 .search(search_query)
1945 .into_iter()
1946 .filter(|skill| visible_names.contains(&skill.name))
1947 .collect()
1948 } else if let Some(cat) = &query.category {
1949 state
1950 .skill_registry
1951 .skills_for_category(cat)
1952 .into_iter()
1953 .filter(|skill| visible_names.contains(&skill.name))
1954 .collect()
1955 } else if let Some(prov) = &query.provider {
1956 state
1957 .skill_registry
1958 .skills_for_provider(prov)
1959 .into_iter()
1960 .filter(|skill| visible_names.contains(&skill.name))
1961 .collect()
1962 } else if let Some(t) = &query.tool {
1963 state
1964 .skill_registry
1965 .skills_for_tool(t)
1966 .into_iter()
1967 .filter(|skill| visible_names.contains(&skill.name))
1968 .collect()
1969 } else {
1970 state
1971 .skill_registry
1972 .list_skills()
1973 .iter()
1974 .filter(|skill| visible_names.contains(&skill.name))
1975 .collect()
1976 };
1977
1978 let json: Vec<Value> = skills
1979 .iter()
1980 .map(|s| {
1981 serde_json::json!({
1982 "name": s.name,
1983 "version": s.version,
1984 "description": s.description,
1985 "tools": s.tools,
1986 "providers": s.providers,
1987 "categories": s.categories,
1988 "hint": s.hint,
1989 })
1990 })
1991 .collect();
1992
1993 (StatusCode::OK, Json(Value::Array(json)))
1994}
1995
1996async fn handle_skill_detail(
1997 State(state): State<Arc<ProxyState>>,
1998 claims: Option<Extension<TokenClaims>>,
1999 axum::extract::Path(name): axum::extract::Path<String>,
2000 axum::extract::Query(query): axum::extract::Query<SkillDetailQuery>,
2001) -> impl IntoResponse {
2002 tracing::debug!(%name, meta = ?query.meta, refs = ?query.refs, "GET /skills/:name");
2003
2004 let claims = claims.map(|Extension(claims)| claims);
2005 let scopes = scopes_for_request(claims.as_ref(), &state);
2006 let visible_names = visible_skill_names(&state, &scopes);
2007
2008 let skill_meta = match state
2009 .skill_registry
2010 .get_skill(&name)
2011 .filter(|skill| visible_names.contains(&skill.name))
2012 {
2013 Some(s) => s,
2014 None => {
2015 return (
2016 StatusCode::NOT_FOUND,
2017 Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
2018 );
2019 }
2020 };
2021
2022 if query.meta.unwrap_or(false) {
2023 return (
2024 StatusCode::OK,
2025 Json(serde_json::json!({
2026 "name": skill_meta.name,
2027 "version": skill_meta.version,
2028 "description": skill_meta.description,
2029 "author": skill_meta.author,
2030 "tools": skill_meta.tools,
2031 "providers": skill_meta.providers,
2032 "categories": skill_meta.categories,
2033 "keywords": skill_meta.keywords,
2034 "hint": skill_meta.hint,
2035 "depends_on": skill_meta.depends_on,
2036 "suggests": skill_meta.suggests,
2037 "license": skill_meta.license,
2038 "compatibility": skill_meta.compatibility,
2039 "allowed_tools": skill_meta.allowed_tools,
2040 "format": skill_meta.format,
2041 })),
2042 );
2043 }
2044
2045 let content = match state.skill_registry.read_content(&name) {
2046 Ok(c) => c,
2047 Err(e) => {
2048 return (
2049 StatusCode::INTERNAL_SERVER_ERROR,
2050 Json(serde_json::json!({"error": format!("Failed to read skill: {e}")})),
2051 );
2052 }
2053 };
2054
2055 let mut response = serde_json::json!({
2056 "name": skill_meta.name,
2057 "version": skill_meta.version,
2058 "description": skill_meta.description,
2059 "content": content,
2060 });
2061
2062 if query.refs.unwrap_or(false) {
2063 if let Ok(refs) = state.skill_registry.list_references(&name) {
2064 response["references"] = serde_json::json!(refs);
2065 }
2066 }
2067
2068 (StatusCode::OK, Json(response))
2069}
2070
2071async fn handle_skill_bundle(
2075 State(state): State<Arc<ProxyState>>,
2076 claims: Option<Extension<TokenClaims>>,
2077 axum::extract::Path(name): axum::extract::Path<String>,
2078) -> impl IntoResponse {
2079 tracing::debug!(skill = %name, "GET /skills/:name/bundle");
2080
2081 let claims = claims.map(|Extension(claims)| claims);
2082 let scopes = scopes_for_request(claims.as_ref(), &state);
2083 let visible_names = visible_skill_names(&state, &scopes);
2084 if !visible_names.contains(&name) {
2085 return (
2086 StatusCode::NOT_FOUND,
2087 Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
2088 );
2089 }
2090
2091 let files = match state.skill_registry.bundle_files(&name) {
2092 Ok(f) => f,
2093 Err(_) => {
2094 return (
2095 StatusCode::NOT_FOUND,
2096 Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
2097 );
2098 }
2099 };
2100
2101 let mut file_map = serde_json::Map::new();
2103 for (path, data) in &files {
2104 match std::str::from_utf8(data) {
2105 Ok(text) => {
2106 file_map.insert(path.clone(), Value::String(text.to_string()));
2107 }
2108 Err(_) => {
2109 use base64::Engine;
2111 let encoded = base64::engine::general_purpose::STANDARD.encode(data);
2112 file_map.insert(path.clone(), serde_json::json!({"base64": encoded}));
2113 }
2114 }
2115 }
2116
2117 (
2118 StatusCode::OK,
2119 Json(serde_json::json!({
2120 "name": name,
2121 "files": file_map,
2122 })),
2123 )
2124}
2125
2126async fn handle_skills_bundle_batch(
2130 State(state): State<Arc<ProxyState>>,
2131 claims: Option<Extension<TokenClaims>>,
2132 Json(req): Json<SkillBundleBatchRequest>,
2133) -> impl IntoResponse {
2134 const MAX_BATCH: usize = 50;
2135 if req.names.len() > MAX_BATCH {
2136 return (
2137 StatusCode::BAD_REQUEST,
2138 Json(
2139 serde_json::json!({"error": format!("batch size {} exceeds limit of {MAX_BATCH}", req.names.len())}),
2140 ),
2141 );
2142 }
2143
2144 tracing::debug!(names = ?req.names, "POST /skills/bundle");
2145
2146 let claims = claims.map(|Extension(claims)| claims);
2147 let scopes = scopes_for_request(claims.as_ref(), &state);
2148 let visible_names = visible_skill_names(&state, &scopes);
2149
2150 let mut result = serde_json::Map::new();
2151 let mut missing: Vec<String> = Vec::new();
2152
2153 for name in &req.names {
2154 if !visible_names.contains(name) {
2155 missing.push(name.clone());
2156 continue;
2157 }
2158 let files = match state.skill_registry.bundle_files(name) {
2159 Ok(f) => f,
2160 Err(_) => {
2161 missing.push(name.clone());
2162 continue;
2163 }
2164 };
2165
2166 let mut file_map = serde_json::Map::new();
2167 for (path, data) in &files {
2168 match std::str::from_utf8(data) {
2169 Ok(text) => {
2170 file_map.insert(path.clone(), Value::String(text.to_string()));
2171 }
2172 Err(_) => {
2173 use base64::Engine;
2174 let encoded = base64::engine::general_purpose::STANDARD.encode(data);
2175 file_map.insert(path.clone(), serde_json::json!({"base64": encoded}));
2176 }
2177 }
2178 }
2179
2180 result.insert(name.clone(), serde_json::json!({ "files": file_map }));
2181 }
2182
2183 (
2184 StatusCode::OK,
2185 Json(serde_json::json!({ "skills": result, "missing": missing })),
2186 )
2187}
2188
2189async fn handle_skills_resolve(
2190 State(state): State<Arc<ProxyState>>,
2191 claims: Option<Extension<TokenClaims>>,
2192 Json(req): Json<SkillResolveRequest>,
2193) -> impl IntoResponse {
2194 tracing::debug!(scopes = ?req.scopes, include_content = req.include_content, "POST /skills/resolve");
2195
2196 let include_content = req.include_content;
2197 let request_scopes = ScopeConfig {
2198 scopes: req.scopes,
2199 sub: String::new(),
2200 expires_at: 0,
2201 rate_config: None,
2202 };
2203 let claims = claims.map(|Extension(claims)| claims);
2204 let caller_scopes = scopes_for_request(claims.as_ref(), &state);
2205 let visible_names = visible_skill_names(&state, &caller_scopes);
2206
2207 let resolved: Vec<&skill::SkillMeta> =
2208 skill::resolve_skills(&state.skill_registry, &state.registry, &request_scopes)
2209 .into_iter()
2210 .filter(|skill| visible_names.contains(&skill.name))
2211 .collect();
2212
2213 let json: Vec<Value> = resolved
2214 .iter()
2215 .map(|s| {
2216 let mut entry = serde_json::json!({
2217 "name": s.name,
2218 "version": s.version,
2219 "description": s.description,
2220 "tools": s.tools,
2221 "providers": s.providers,
2222 "categories": s.categories,
2223 });
2224 if include_content {
2225 if let Ok(content) = state.skill_registry.read_content(&s.name) {
2226 entry["content"] = Value::String(content);
2227 }
2228 }
2229 entry
2230 })
2231 .collect();
2232
2233 (StatusCode::OK, Json(Value::Array(json)))
2234}
2235
2236fn skillati_client(keyring: &Keyring) -> Result<SkillAtiClient, SkillAtiError> {
2237 match SkillAtiClient::from_env(keyring)? {
2238 Some(client) => Ok(client),
2239 None => Err(SkillAtiError::NotConfigured),
2240 }
2241}
2242
2243async fn handle_skillati_catalog(
2244 State(state): State<Arc<ProxyState>>,
2245 claims: Option<Extension<TokenClaims>>,
2246 Query(query): Query<SkillAtiCatalogQuery>,
2247) -> impl IntoResponse {
2248 tracing::debug!(search = ?query.search, "GET /skillati/catalog");
2249
2250 let client = match skillati_client(&state.keyring) {
2251 Ok(client) => client,
2252 Err(err) => return skillati_error_response(err),
2253 };
2254
2255 let claims = claims.map(|Extension(c)| c);
2256 let scopes = scopes_for_request(claims.as_ref(), &state);
2257
2258 match client.catalog().await {
2259 Ok(catalog) => {
2260 let mut visible_names = visible_skill_names(&state, &scopes);
2264 visible_names.extend(visible_remote_skill_names(&state, &scopes, &catalog));
2265
2266 let mut skills: Vec<_> = catalog
2267 .into_iter()
2268 .filter(|s| visible_names.contains(&s.name))
2269 .collect();
2270 if let Some(search) = query.search.as_deref() {
2271 skills = SkillAtiClient::filter_catalog(&skills, search, 25);
2272 }
2273 (
2274 StatusCode::OK,
2275 Json(serde_json::json!({
2276 "skills": skills,
2277 })),
2278 )
2279 }
2280 Err(err) => skillati_error_response(err),
2281 }
2282}
2283
2284async fn handle_skillati_read(
2285 State(state): State<Arc<ProxyState>>,
2286 claims: Option<Extension<TokenClaims>>,
2287 axum::extract::Path(name): axum::extract::Path<String>,
2288) -> impl IntoResponse {
2289 tracing::debug!(%name, "GET /skillati/:name");
2290
2291 let client = match skillati_client(&state.keyring) {
2292 Ok(client) => client,
2293 Err(err) => return skillati_error_response(err),
2294 };
2295
2296 let claims = claims.map(|Extension(c)| c);
2297 let scopes = scopes_for_request(claims.as_ref(), &state);
2298 let visible_names = match visible_skill_names_with_remote(&state, &scopes, &client).await {
2299 Ok(v) => v,
2300 Err(err) => return skillati_error_response(err),
2301 };
2302 if !visible_names.contains(&name) {
2303 return skillati_error_response(SkillAtiError::SkillNotFound(name));
2304 }
2305
2306 match client.read_skill(&name).await {
2307 Ok(activation) => (StatusCode::OK, Json(serde_json::json!(activation))),
2308 Err(err) => skillati_error_response(err),
2309 }
2310}
2311
2312async fn handle_skillati_resources(
2313 State(state): State<Arc<ProxyState>>,
2314 claims: Option<Extension<TokenClaims>>,
2315 axum::extract::Path(name): axum::extract::Path<String>,
2316 Query(query): Query<SkillAtiResourcesQuery>,
2317) -> impl IntoResponse {
2318 tracing::debug!(%name, prefix = ?query.prefix, "GET /skillati/:name/resources");
2319
2320 let client = match skillati_client(&state.keyring) {
2321 Ok(client) => client,
2322 Err(err) => return skillati_error_response(err),
2323 };
2324
2325 let claims = claims.map(|Extension(c)| c);
2326 let scopes = scopes_for_request(claims.as_ref(), &state);
2327 let visible_names = match visible_skill_names_with_remote(&state, &scopes, &client).await {
2328 Ok(v) => v,
2329 Err(err) => return skillati_error_response(err),
2330 };
2331 if !visible_names.contains(&name) {
2332 return skillati_error_response(SkillAtiError::SkillNotFound(name));
2333 }
2334
2335 match client.list_resources(&name, query.prefix.as_deref()).await {
2336 Ok(resources) => (
2337 StatusCode::OK,
2338 Json(serde_json::json!({
2339 "name": name,
2340 "prefix": query.prefix,
2341 "resources": resources,
2342 })),
2343 ),
2344 Err(err) => skillati_error_response(err),
2345 }
2346}
2347
2348async fn handle_skillati_file(
2349 State(state): State<Arc<ProxyState>>,
2350 claims: Option<Extension<TokenClaims>>,
2351 axum::extract::Path(name): axum::extract::Path<String>,
2352 Query(query): Query<SkillAtiFileQuery>,
2353) -> impl IntoResponse {
2354 tracing::debug!(%name, path = %query.path, "GET /skillati/:name/file");
2355
2356 let client = match skillati_client(&state.keyring) {
2357 Ok(client) => client,
2358 Err(err) => return skillati_error_response(err),
2359 };
2360
2361 let claims = claims.map(|Extension(c)| c);
2362 let scopes = scopes_for_request(claims.as_ref(), &state);
2363 let visible_names = match visible_skill_names_with_remote(&state, &scopes, &client).await {
2364 Ok(v) => v,
2365 Err(err) => return skillati_error_response(err),
2366 };
2367 if !visible_names.contains(&name) {
2368 return skillati_error_response(SkillAtiError::SkillNotFound(name));
2369 }
2370
2371 match client.read_path(&name, &query.path).await {
2372 Ok(file) => (StatusCode::OK, Json(serde_json::json!(file))),
2373 Err(err) => skillati_error_response(err),
2374 }
2375}
2376
2377async fn handle_skillati_refs(
2378 State(state): State<Arc<ProxyState>>,
2379 claims: Option<Extension<TokenClaims>>,
2380 axum::extract::Path(name): axum::extract::Path<String>,
2381) -> impl IntoResponse {
2382 tracing::debug!(%name, "GET /skillati/:name/refs");
2383
2384 let client = match skillati_client(&state.keyring) {
2385 Ok(client) => client,
2386 Err(err) => return skillati_error_response(err),
2387 };
2388
2389 let claims = claims.map(|Extension(c)| c);
2390 let scopes = scopes_for_request(claims.as_ref(), &state);
2391 let visible_names = match visible_skill_names_with_remote(&state, &scopes, &client).await {
2392 Ok(v) => v,
2393 Err(err) => return skillati_error_response(err),
2394 };
2395 if !visible_names.contains(&name) {
2396 return skillati_error_response(SkillAtiError::SkillNotFound(name));
2397 }
2398
2399 match client.list_references(&name).await {
2400 Ok(references) => (
2401 StatusCode::OK,
2402 Json(serde_json::json!({
2403 "name": name,
2404 "references": references,
2405 })),
2406 ),
2407 Err(err) => skillati_error_response(err),
2408 }
2409}
2410
2411async fn handle_skillati_ref(
2412 State(state): State<Arc<ProxyState>>,
2413 claims: Option<Extension<TokenClaims>>,
2414 axum::extract::Path((name, reference)): axum::extract::Path<(String, String)>,
2415) -> impl IntoResponse {
2416 tracing::debug!(%name, %reference, "GET /skillati/:name/ref/:reference");
2417
2418 let client = match skillati_client(&state.keyring) {
2419 Ok(client) => client,
2420 Err(err) => return skillati_error_response(err),
2421 };
2422
2423 let claims = claims.map(|Extension(c)| c);
2424 let scopes = scopes_for_request(claims.as_ref(), &state);
2425 let visible_names = match visible_skill_names_with_remote(&state, &scopes, &client).await {
2426 Ok(v) => v,
2427 Err(err) => return skillati_error_response(err),
2428 };
2429 if !visible_names.contains(&name) {
2430 return skillati_error_response(SkillAtiError::SkillNotFound(name));
2431 }
2432
2433 match client.read_reference(&name, &reference).await {
2434 Ok(content) => (
2435 StatusCode::OK,
2436 Json(serde_json::json!({
2437 "name": name,
2438 "reference": reference,
2439 "content": content,
2440 })),
2441 ),
2442 Err(err) => skillati_error_response(err),
2443 }
2444}
2445
2446fn skillati_error_response(err: SkillAtiError) -> (StatusCode, Json<Value>) {
2447 let status = match &err {
2448 SkillAtiError::NotConfigured
2449 | SkillAtiError::UnsupportedRegistry(_)
2450 | SkillAtiError::MissingCredentials(_)
2451 | SkillAtiError::ProxyUrlRequired => StatusCode::SERVICE_UNAVAILABLE,
2452 SkillAtiError::SkillNotFound(_) | SkillAtiError::PathNotFound { .. } => {
2453 StatusCode::NOT_FOUND
2454 }
2455 SkillAtiError::InvalidPath(_) => StatusCode::BAD_REQUEST,
2456 SkillAtiError::Gcs(_)
2457 | SkillAtiError::ProxyRequest(_)
2458 | SkillAtiError::ProxyResponse(_) => StatusCode::BAD_GATEWAY,
2459 };
2460
2461 (
2462 status,
2463 Json(serde_json::json!({
2464 "error": err.to_string(),
2465 })),
2466 )
2467}
2468
2469async fn auth_middleware(
2477 State(state): State<Arc<ProxyState>>,
2478 mut req: HttpRequest<Body>,
2479 next: Next,
2480) -> Result<Response, StatusCode> {
2481 let path = req.uri().path();
2482
2483 if path == "/health" || path == "/.well-known/jwks.json" {
2485 return Ok(next.run(req).await);
2486 }
2487
2488 let jwt_config = match &state.jwt_config {
2490 Some(c) => c,
2491 None => return Ok(next.run(req).await),
2492 };
2493
2494 let token_owned: String = match req
2499 .headers()
2500 .get("authorization")
2501 .and_then(|v| v.to_str().ok())
2502 {
2503 Some(header) if header.starts_with("Bearer ") => header[7..].to_string(),
2504 _ => return Err(StatusCode::UNAUTHORIZED),
2505 };
2506
2507 match jwt::validate(&token_owned, jwt_config) {
2509 Ok(claims) => {
2510 tracing::debug!(sub = %claims.sub, scopes = %claims.scope, "JWT validated");
2511 req.extensions_mut().insert(BearerToken(token_owned));
2517 req.extensions_mut().insert(claims);
2518 Ok(next.run(req).await)
2519 }
2520 Err(e) => {
2521 tracing::debug!(error = %e, "JWT validation failed");
2522 Err(StatusCode::UNAUTHORIZED)
2523 }
2524 }
2525}
2526
/// Raw bearer token taken from the `Authorization` header.
///
/// The auth middleware stores it in the request extensions after the JWT validates.
///
/// `Debug` is implemented by hand so that `{:?}` formatting (logs, panic messages,
/// derived `Debug` of containing types) never exposes the credential. Read `.0`
/// explicitly when the token itself is needed.
#[derive(Clone)]
pub struct BearerToken(pub String);

impl std::fmt::Debug for BearerToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("BearerToken(<redacted>)")
    }
}
2535
2536fn max_call_body_bytes() -> usize {
2546 (crate::core::file_manager::MAX_UPLOAD_BYTES as usize)
2547 .saturating_mul(4)
2548 .saturating_div(3)
2549 .saturating_add(8 * 1024)
2550}
2551
2552pub fn build_router(state: Arc<ProxyState>) -> Router {
2553 use axum::extract::DefaultBodyLimit;
2554
2555 Router::new()
2556 .route("/call", post(handle_call))
2557 .route("/help", post(handle_help))
2558 .route("/mcp", post(handle_mcp))
2559 .route("/tools", get(handle_tools_list))
2560 .route("/tools/{name}", get(handle_tool_info))
2561 .route("/skills", get(handle_skills_list))
2562 .route("/skills/resolve", post(handle_skills_resolve))
2563 .route("/skills/bundle", post(handle_skills_bundle_batch))
2564 .route("/skills/{name}", get(handle_skill_detail))
2565 .route("/skills/{name}/bundle", get(handle_skill_bundle))
2566 .route("/skillati/catalog", get(handle_skillati_catalog))
2567 .route("/skillati/{name}", get(handle_skillati_read))
2568 .route("/skillati/{name}/resources", get(handle_skillati_resources))
2569 .route("/skillati/{name}/file", get(handle_skillati_file))
2570 .route("/skillati/{name}/refs", get(handle_skillati_refs))
2571 .route("/skillati/{name}/ref/{reference}", get(handle_skillati_ref))
2572 .route("/health", get(handle_health))
2573 .route("/.well-known/jwks.json", get(handle_jwks))
2574 .layer(DefaultBodyLimit::max(max_call_body_bytes()))
2579 .layer(middleware::from_fn_with_state(
2580 state.clone(),
2581 auth_middleware,
2582 ))
2583 .with_state(state)
2584}
2585
2586pub async fn run(
2590 port: u16,
2591 bind_addr: Option<String>,
2592 ati_dir: PathBuf,
2593 _verbose: bool,
2594 env_keys: bool,
2595) -> Result<(), Box<dyn std::error::Error>> {
2596 let manifests_dir = ati_dir.join("manifests");
2598 let mut registry = ManifestRegistry::load(&manifests_dir)?;
2599 let provider_count = registry.list_providers().len();
2600
2601 let keyring_source;
2603 let keyring = if env_keys {
2604 let kr = Keyring::from_env();
2606 let key_names = kr.key_names();
2607 tracing::info!(
2608 count = key_names.len(),
2609 "loaded API keys from ATI_KEY_* env vars"
2610 );
2611 for name in &key_names {
2612 tracing::debug!(key = %name, "env key loaded");
2613 }
2614 keyring_source = "env-vars (ATI_KEY_*)";
2615 kr
2616 } else {
2617 let keyring_path = ati_dir.join("keyring.enc");
2619 if keyring_path.exists() {
2620 if let Ok(kr) = Keyring::load(&keyring_path) {
2621 keyring_source = "keyring.enc (sealed key)";
2622 kr
2623 } else if let Ok(kr) = Keyring::load_local(&keyring_path, &ati_dir) {
2624 keyring_source = "keyring.enc (persistent key)";
2625 kr
2626 } else {
2627 tracing::warn!("keyring.enc exists but could not be decrypted");
2628 keyring_source = "empty (decryption failed)";
2629 Keyring::empty()
2630 }
2631 } else {
2632 let creds_path = ati_dir.join("credentials");
2633 if creds_path.exists() {
2634 match Keyring::load_credentials(&creds_path) {
2635 Ok(kr) => {
2636 keyring_source = "credentials (plaintext)";
2637 kr
2638 }
2639 Err(e) => {
2640 tracing::warn!(error = %e, "failed to load credentials");
2641 keyring_source = "empty (credentials error)";
2642 Keyring::empty()
2643 }
2644 }
2645 } else {
2646 tracing::warn!("no keyring.enc or credentials found — running without API keys");
2647 tracing::warn!("tools requiring authentication will fail");
2648 keyring_source = "empty (no auth)";
2649 Keyring::empty()
2650 }
2651 }
2652 };
2653
2654 mcp_client::discover_all_mcp_tools(&mut registry, &keyring).await;
2657
2658 let tool_count = registry.list_public_tools().len();
2659
2660 let mcp_providers: Vec<(String, String)> = registry
2662 .list_mcp_providers()
2663 .iter()
2664 .map(|p| (p.name.clone(), p.mcp_transport_type().to_string()))
2665 .collect();
2666 let mcp_count = mcp_providers.len();
2667 let openapi_providers: Vec<String> = registry
2668 .list_openapi_providers()
2669 .iter()
2670 .map(|p| p.name.clone())
2671 .collect();
2672 let openapi_count = openapi_providers.len();
2673
2674 let skills_dir = ati_dir.join("skills");
2676 let skill_registry = SkillRegistry::load(&skills_dir).unwrap_or_else(|e| {
2677 tracing::warn!(error = %e, "failed to load skills");
2678 SkillRegistry::load(std::path::Path::new("/nonexistent-fallback")).unwrap()
2679 });
2680
2681 if let Ok(registry_url) = std::env::var("ATI_SKILL_REGISTRY") {
2682 if registry_url.strip_prefix("gcs://").is_some() {
2683 tracing::info!(
2684 registry = %registry_url,
2685 "SkillATI remote registry configured for lazy reads"
2686 );
2687 } else {
2688 tracing::warn!(url = %registry_url, "SkillATI only supports gcs:// registries");
2689 }
2690 }
2691
2692 let skill_count = skill_registry.skill_count();
2693
2694 let jwt_config = match jwt::config_from_env() {
2696 Ok(config) => config,
2697 Err(e) => {
2698 tracing::warn!(error = %e, "JWT config error");
2699 None
2700 }
2701 };
2702
2703 let auth_status = if jwt_config.is_some() {
2704 "JWT enabled"
2705 } else {
2706 "DISABLED (no JWT keys configured)"
2707 };
2708
2709 let jwks_json = jwt_config.as_ref().and_then(|config| {
2711 config
2712 .public_key_pem
2713 .as_ref()
2714 .and_then(|pem| jwt::public_key_to_jwks(pem, config.algorithm, "ati-proxy-1").ok())
2715 });
2716
2717 let state = Arc::new(ProxyState {
2718 registry,
2719 skill_registry,
2720 keyring,
2721 jwt_config,
2722 jwks_json,
2723 auth_cache: AuthCache::new(),
2724 upstream_url_allowlists: std::sync::Arc::new(std::sync::Mutex::new(
2725 std::collections::HashMap::new(),
2726 )),
2727 lazy_schema_cache: std::sync::Arc::new(std::sync::Mutex::new(
2728 std::collections::HashMap::new(),
2729 )),
2730 });
2731
2732 let app = build_router(state);
2733
2734 let addr: SocketAddr = if let Some(ref bind) = bind_addr {
2735 format!("{bind}:{port}").parse()?
2736 } else {
2737 SocketAddr::from(([127, 0, 0, 1], port))
2738 };
2739
2740 tracing::info!(
2741 version = env!("CARGO_PKG_VERSION"),
2742 %addr,
2743 auth = auth_status,
2744 ati_dir = %ati_dir.display(),
2745 tools = tool_count,
2746 providers = provider_count,
2747 mcp = mcp_count,
2748 openapi = openapi_count,
2749 skills = skill_count,
2750 keyring = keyring_source,
2751 "ATI proxy server starting"
2752 );
2753 for (name, transport) in &mcp_providers {
2754 tracing::info!(provider = %name, transport = %transport, "MCP provider");
2755 }
2756 for name in &openapi_providers {
2757 tracing::info!(provider = %name, "OpenAPI provider");
2758 }
2759
2760 let listener = tokio::net::TcpListener::bind(addr).await?;
2761 axum::serve(listener, app).await?;
2762
2763 Ok(())
2764}
2765
2766async fn dispatch_file_manager(
2769 tool_name: &str,
2770 args: &HashMap<String, Value>,
2771 provider: &Provider,
2772 keyring: &Keyring,
2773) -> Result<Value, (StatusCode, String)> {
2774 use crate::core::file_manager::{self, DownloadArgs, FileManagerError, UploadArgs};
2775
2776 let to_resp = |e: FileManagerError| {
2779 let status =
2780 StatusCode::from_u16(e.http_status()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
2781 (status, e.to_string())
2782 };
2783
2784 match tool_name {
2785 "file_manager:download" => {
2786 let parsed = DownloadArgs::from_value(args).map_err(to_resp)?;
2787 let result = file_manager::fetch_bytes(&parsed).await.map_err(to_resp)?;
2788 Ok(file_manager::build_download_response(&result))
2789 }
2790 "file_manager:upload" => {
2791 let parsed = UploadArgs::from_wire(args).map_err(to_resp)?;
2792 file_manager::upload_to_destination(
2793 parsed,
2794 &provider.upload_destinations,
2795 provider.upload_default_destination.as_deref(),
2796 keyring,
2797 )
2798 .await
2799 .map_err(to_resp)
2800 }
2801 other => Err((
2802 StatusCode::NOT_FOUND,
2803 format!("Unknown file_manager tool: '{other}'"),
2804 )),
2805 }
2806}
2807
2808fn write_proxy_audit(
2809 call_req: &CallRequest,
2810 agent_sub: &str,
2811 claims: Option<&TokenClaims>,
2812 duration: std::time::Duration,
2813 error: Option<&str>,
2814) {
2815 let entry = crate::core::audit::AuditEntry {
2816 ts: chrono::Utc::now().to_rfc3339(),
2817 tool: call_req.tool_name.clone(),
2818 args: crate::core::audit::sanitize_args(&call_req.args),
2819 status: if error.is_some() {
2820 crate::core::audit::AuditStatus::Error
2821 } else {
2822 crate::core::audit::AuditStatus::Ok
2823 },
2824 duration_ms: duration.as_millis() as u64,
2825 agent_sub: agent_sub.to_string(),
2826 job_id: claims.and_then(|c| c.job_id.clone()),
2827 sandbox_id: claims.and_then(|c| c.sandbox_id.clone()),
2828 error: error.map(|s| s.to_string()),
2829 exit_code: None,
2830 };
2831 let _ = crate::core::audit::append(&entry);
2832}
2833
/// Generic system prompt for the proxy's help assistant.
///
/// This is a plain `&str` constant, not a `format!` template. The `{tools}` and
/// `{skills_section}` markers are literal text that the caller must substitute.
/// NOTE(review): the substitution site is not visible here (presumably `str::replace`).
/// The `{tools}` body is presumably produced by `build_tool_context`, and `{skills_section}`
/// by `merge_help_skill_sections`. Confirm at the call site.
const HELP_SYSTEM_PROMPT: &str = r#"You are a helpful assistant for an AI agent that uses external tools via the `ati` CLI.

## Available Tools
{tools}

{skills_section}

Answer the agent's question naturally, like a knowledgeable colleague would. Keep it short but useful:

- Explain which tools to use and why, with `ati run` commands showing realistic parameter values
- If multiple steps are needed, walk through them briefly in order
- Mention important gotchas or parameter choices that matter
- If skills are relevant, tell the agent to load them using the Skill tool (e.g., `skill: "research-financial-data"`)

Keep your answer concise — a few short paragraphs with embedded code blocks. Only recommend tools from the list above."#;
2851
2852async fn build_remote_skillati_section(keyring: &Keyring, query: &str, limit: usize) -> String {
2853 let client = match SkillAtiClient::from_env(keyring) {
2854 Ok(Some(client)) => client,
2855 Ok(None) => return String::new(),
2856 Err(err) => {
2857 tracing::warn!(error = %err, "failed to initialize SkillATI catalog for proxy help");
2858 return String::new();
2859 }
2860 };
2861
2862 let catalog = match client.catalog().await {
2863 Ok(catalog) => catalog,
2864 Err(err) => {
2865 tracing::warn!(error = %err, "failed to load SkillATI catalog for proxy help");
2866 return String::new();
2867 }
2868 };
2869
2870 let matched = SkillAtiClient::filter_catalog(&catalog, query, limit);
2871 if matched.is_empty() {
2872 return String::new();
2873 }
2874
2875 render_remote_skillati_section(&matched, catalog.len())
2876}
2877
2878fn render_remote_skillati_section(skills: &[RemoteSkillMeta], total_catalog: usize) -> String {
2879 let mut section = String::from("## Remote Skills Available Via SkillATI\n\n");
2880 section.push_str(
2881 "These skills are available. Load them using the Skill tool (e.g., `skill: \"skill-name\"`).\n\n",
2882 );
2883
2884 for skill in skills {
2885 section.push_str(&format!("- **{}**: {}\n", skill.name, skill.description));
2886 }
2887
2888 if total_catalog > skills.len() {
2889 section.push_str(&format!(
2890 "\nOnly the most relevant {} remote skills are shown here.\n",
2891 skills.len()
2892 ));
2893 }
2894
2895 section
2896}
2897
/// Joins help skill sections with a blank line between them, trimming each one.
///
/// Empty or whitespace-only sections are dropped entirely, so no stray separators appear when
/// a source contributed nothing. An input with no non-blank sections yields an empty string.
fn merge_help_skill_sections(sections: &[String]) -> String {
    let non_blank: Vec<&str> = sections
        .iter()
        .map(|section| section.trim())
        .filter(|trimmed| !trimmed.is_empty())
        .collect();
    non_blank.join("\n\n")
}
2912
2913fn build_tool_context(
2914 tools: &[(
2915 &crate::core::manifest::Provider,
2916 &crate::core::manifest::Tool,
2917 )],
2918) -> String {
2919 let mut summaries = Vec::new();
2920 for (provider, tool) in tools {
2921 let mut summary = if let Some(cat) = &provider.category {
2922 format!(
2923 "- **{}** (provider: {}, category: {}): {}",
2924 tool.name, provider.name, cat, tool.description
2925 )
2926 } else {
2927 format!(
2928 "- **{}** (provider: {}): {}",
2929 tool.name, provider.name, tool.description
2930 )
2931 };
2932 if !tool.tags.is_empty() {
2933 summary.push_str(&format!("\n Tags: {}", tool.tags.join(", ")));
2934 }
2935 if provider.is_cli() && tool.input_schema.is_none() {
2937 let cmd = provider.cli_command.as_deref().unwrap_or("?");
2938 summary.push_str(&format!(
2939 "\n Usage: `ati run {} -- <args>` (passthrough to `{}`)",
2940 tool.name, cmd
2941 ));
2942 } else if let Some(schema) = &tool.input_schema {
2943 if let Some(props) = schema.get("properties") {
2944 if let Some(obj) = props.as_object() {
2945 let params: Vec<String> = obj
2946 .iter()
2947 .filter(|(_, v)| {
2948 v.get("x-ati-param-location").is_none()
2949 || v.get("description").is_some()
2950 })
2951 .map(|(k, v)| {
2952 let type_str =
2953 v.get("type").and_then(|t| t.as_str()).unwrap_or("string");
2954 let desc = v.get("description").and_then(|d| d.as_str()).unwrap_or("");
2955 format!(" --{k} ({type_str}): {desc}")
2956 })
2957 .collect();
2958 if !params.is_empty() {
2959 summary.push_str("\n Parameters:\n");
2960 summary.push_str(¶ms.join("\n"));
2961 }
2962 }
2963 }
2964 }
2965 summaries.push(summary);
2966 }
2967 summaries.join("\n\n")
2968}
2969
2970fn build_scoped_prompt(
2974 scope_name: &str,
2975 visible_tools: &[(&Provider, &Tool)],
2976 skills_section: &str,
2977) -> Option<String> {
2978 if let Some((provider, tool)) = visible_tools
2980 .iter()
2981 .find(|(_, tool)| tool.name == scope_name)
2982 {
2983 let mut details = format!(
2984 "**Name**: `{}`\n**Provider**: {} (handler: {})\n**Description**: {}\n",
2985 tool.name, provider.name, provider.handler, tool.description
2986 );
2987 if let Some(cat) = &provider.category {
2988 details.push_str(&format!("**Category**: {}\n", cat));
2989 }
2990 if provider.is_cli() {
2991 let cmd = provider.cli_command.as_deref().unwrap_or("?");
2992 details.push_str(&format!(
2993 "\n**Usage**: `ati run {} -- <args>` (passthrough to `{}`)\n",
2994 tool.name, cmd
2995 ));
2996 } else if let Some(schema) = &tool.input_schema {
2997 if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
2998 let required: Vec<String> = schema
2999 .get("required")
3000 .and_then(|r| r.as_array())
3001 .map(|arr| {
3002 arr.iter()
3003 .filter_map(|v| v.as_str().map(|s| s.to_string()))
3004 .collect()
3005 })
3006 .unwrap_or_default();
3007 details.push_str("\n**Parameters**:\n");
3008 for (key, val) in props {
3009 let type_str = val.get("type").and_then(|t| t.as_str()).unwrap_or("string");
3010 let desc = val
3011 .get("description")
3012 .and_then(|d| d.as_str())
3013 .unwrap_or("");
3014 let req = if required.contains(key) {
3015 " **(required)**"
3016 } else {
3017 ""
3018 };
3019 details.push_str(&format!("- `--{key}` ({type_str}{req}): {desc}\n"));
3020 }
3021 }
3022 }
3023
3024 let prompt = format!(
3025 "You are an expert assistant for the `{}` tool, accessed via the `ati` CLI.\n\n\
3026 ## Tool Details\n{}\n\n{}\n\n\
3027 Answer the agent's question about this specific tool. Provide exact commands, explain flags and options, and give practical examples. Be concise and actionable.",
3028 tool.name, details, skills_section
3029 );
3030 return Some(prompt);
3031 }
3032
3033 let tools: Vec<(&Provider, &Tool)> = visible_tools
3035 .iter()
3036 .copied()
3037 .filter(|(provider, _)| provider.name == scope_name)
3038 .collect();
3039 if !tools.is_empty() {
3040 let tools_context = build_tool_context(&tools);
3041 let prompt = format!(
3042 "You are an expert assistant for the `{}` provider's tools, accessed via the `ati` CLI.\n\n\
3043 ## Tools in provider `{}`\n{}\n\n{}\n\n\
3044 Answer the agent's question about these tools. Provide exact `ati run` commands, explain parameters, and give practical examples. Be concise and actionable.",
3045 scope_name, scope_name, tools_context, skills_section
3046 );
3047 return Some(prompt);
3048 }
3049
3050 None
3051}