1use axum::{
8 body::Body,
9 extract::{Extension, Query, State},
10 http::{Request as HttpRequest, StatusCode},
11 middleware::{self, Next},
12 response::{IntoResponse, Response},
13 routing::{get, post},
14 Json, Router,
15};
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use std::collections::HashMap;
19use std::net::SocketAddr;
20use std::path::PathBuf;
21use std::sync::Arc;
22
23use crate::core::auth_generator::{AuthCache, GenContext};
24use crate::core::http;
25use crate::core::jwt::{self, JwtConfig, TokenClaims};
26use crate::core::keyring::Keyring;
27use crate::core::manifest::{ManifestRegistry, Provider, Tool};
28use crate::core::mcp_client;
29use crate::core::response;
30use crate::core::scope::ScopeConfig;
31use crate::core::sentry_scope;
32use crate::core::skill::{self, SkillRegistry};
33use crate::core::skillati::{RemoteSkillMeta, SkillAtiClient, SkillAtiError};
34
/// Shared state handed to every proxy handler as `State<Arc<ProxyState>>`.
pub struct ProxyState {
    /// Manifest registry used to look up providers and tools by name.
    pub registry: ManifestRegistry,
    /// Skill registry backing the skill listing/detail endpoints.
    pub skill_registry: SkillRegistry,
    /// Keyring passed to the tool executors and used to find the help LLM's API key.
    pub keyring: Keyring,
    /// JWT configuration; `None` is reported as auth "disabled" and yields unrestricted scopes.
    pub jwt_config: Option<JwtConfig>,
    /// JWKS document served by the JWKS handler; `None` makes that handler answer 404.
    pub jwks_json: Option<Value>,
    /// Cache passed to the executors alongside the per-call `GenContext`.
    pub auth_cache: AuthCache,
}
47
/// JSON body accepted by `POST /call`.
#[derive(Debug, Deserialize)]
pub struct CallRequest {
    /// Name of the tool to invoke.
    pub tool_name: String,
    /// Tool arguments; defaults to an empty JSON object when the field is omitted.
    #[serde(default = "default_args")]
    pub args: Value,
    /// Optional pre-split argument list; when present it wins over `args` in
    /// `args_as_positional`.
    #[serde(default)]
    pub raw_args: Option<Vec<String>>,
}
63
64fn default_args() -> Value {
65 Value::Object(serde_json::Map::new())
66}
67
68impl CallRequest {
69 fn args_as_map(&self) -> HashMap<String, Value> {
73 match &self.args {
74 Value::Object(map) => map.iter().map(|(k, v)| (k.clone(), v.clone())).collect(),
75 _ => HashMap::new(),
76 }
77 }
78
79 fn args_as_positional(&self) -> Vec<String> {
82 if let Some(ref raw) = self.raw_args {
84 return raw.clone();
85 }
86 match &self.args {
87 Value::Array(arr) => arr
89 .iter()
90 .map(|v| match v {
91 Value::String(s) => s.clone(),
92 other => other.to_string(),
93 })
94 .collect(),
95 Value::String(s) => s.split_whitespace().map(String::from).collect(),
97 Value::Object(map) => {
99 if let Some(Value::Array(pos)) = map.get("_positional") {
100 return pos
101 .iter()
102 .map(|v| match v {
103 Value::String(s) => s.clone(),
104 other => other.to_string(),
105 })
106 .collect();
107 }
108 let mut result = Vec::new();
110 for (k, v) in map {
111 result.push(format!("--{k}"));
112 match v {
113 Value::String(s) => result.push(s.clone()),
114 Value::Bool(true) => {} other => result.push(other.to_string()),
116 }
117 }
118 result
119 }
120 _ => Vec::new(),
121 }
122 }
123}
124
/// JSON body returned by `POST /call`.
#[derive(Debug, Serialize)]
pub struct CallResponse {
    /// Tool result; the failure paths in `handle_call` set this to `null`, except the
    /// response-processing failure, which returns the raw upstream response here.
    pub result: Value,
    /// Error description; omitted from the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
131
/// JSON body accepted by `POST /help`.
#[derive(Debug, Deserialize)]
pub struct HelpRequest {
    /// The user's question; forwarded to the LLM as the user message.
    pub query: String,
    /// Optional tool name; when set, the help prompt is built for that name only.
    #[serde(default)]
    pub tool: Option<String>,
}
138
/// JSON body returned by `POST /help`.
#[derive(Debug, Serialize)]
pub struct HelpResponse {
    /// Text produced by the LLM; empty on the error paths.
    pub content: String,
    /// Error description; omitted from the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
145
/// JSON body returned by the health handler.
#[derive(Debug, Serialize)]
pub struct HealthResponse {
    /// Always `"ok"` as produced by `handle_health`.
    pub status: String,
    /// Crate version taken from `CARGO_PKG_VERSION` at compile time.
    pub version: String,
    /// Number of public tools in the manifest registry.
    pub tools: usize,
    /// Number of providers in the manifest registry.
    pub providers: usize,
    /// Number of skills in the skill registry.
    pub skills: usize,
    /// `"jwt"` when a JWT configuration is present, otherwise `"disabled"`.
    pub auth: String,
}
155
/// Query-string parameters for `GET /skills`.
///
/// `handle_skills_list` applies only the first selector that is present, checked in the
/// order `search`, `category`, `provider`, `tool`; with none set it lists every skill.
#[derive(Debug, Deserialize)]
pub struct SkillsQuery {
    /// Restrict to skills registered for this category.
    #[serde(default)]
    pub category: Option<String>,
    /// Restrict to skills registered for this provider.
    #[serde(default)]
    pub provider: Option<String>,
    /// Restrict to skills registered for this tool.
    #[serde(default)]
    pub tool: Option<String>,
    /// Search query handed to the skill registry's `search`.
    #[serde(default)]
    pub search: Option<String>,
}
169
/// Query-string parameters for `GET /skills/:name`.
#[derive(Debug, Deserialize)]
pub struct SkillDetailQuery {
    /// When `true`, only the skill's metadata is returned (no content is read).
    #[serde(default)]
    pub meta: Option<bool>,
    /// When `true`, the skill's reference list is added to the response if it can be listed.
    #[serde(default)]
    pub refs: Option<bool>,
}
177
/// Request body for resolving skills from a list of scopes.
// NOTE(review): the handler consuming this type is outside the visible chunk — confirm usage.
#[derive(Debug, Deserialize)]
pub struct SkillResolveRequest {
    /// Scope strings to resolve against (required).
    pub scopes: Vec<String>,
    /// Defaults to `false` when omitted; presumably asks for skill content in the
    /// response — verify against the handler.
    #[serde(default)]
    pub include_content: bool,
}
185
/// Request body naming several skills at once.
// NOTE(review): presumably used by a batch bundle endpoint not visible here — confirm.
#[derive(Debug, Deserialize)]
pub struct SkillBundleBatchRequest {
    /// Skill names to include (required).
    pub names: Vec<String>,
}
190
/// Query-string parameters for a SkillATI catalog request.
// NOTE(review): the handler is outside the visible chunk — confirm how `search` is applied.
#[derive(Debug, Deserialize, Default)]
pub struct SkillAtiCatalogQuery {
    /// Optional search text; absent when omitted.
    #[serde(default)]
    pub search: Option<String>,
}
196
/// Query-string parameters for a SkillATI resources request.
// NOTE(review): the handler is outside the visible chunk — confirm how `prefix` is applied.
#[derive(Debug, Deserialize, Default)]
pub struct SkillAtiResourcesQuery {
    /// Optional prefix; absent when omitted.
    #[serde(default)]
    pub prefix: Option<String>,
}
202
/// Query-string parameters for a SkillATI file request.
// NOTE(review): the handler is outside the visible chunk — confirm what `path` is relative to.
#[derive(Debug, Deserialize)]
pub struct SkillAtiFileQuery {
    /// Path of the requested file (required).
    pub path: String,
}
207
/// Query-string parameters for `GET /tools`.
#[derive(Debug, Deserialize)]
pub struct ToolsQuery {
    /// Keep only tools whose provider name equals this value exactly.
    #[serde(default)]
    pub provider: Option<String>,
    /// Keep only tools whose name, description, or any tag contains this text
    /// (compared case-insensitively by `handle_tools_list`).
    #[serde(default)]
    pub search: Option<String>,
}
217
218fn scopes_for_request(claims: Option<&TokenClaims>, state: &ProxyState) -> ScopeConfig {
221 match claims {
222 Some(claims) => ScopeConfig::from_jwt(claims),
223 None if state.jwt_config.is_none() => ScopeConfig::unrestricted(),
224 None => ScopeConfig {
225 scopes: Vec::new(),
226 sub: String::new(),
227 expires_at: 0,
228 rate_config: None,
229 },
230 }
231}
232
233fn visible_tools_for_scopes<'a>(
234 state: &'a ProxyState,
235 scopes: &ScopeConfig,
236) -> Vec<(&'a Provider, &'a Tool)> {
237 crate::core::scope::filter_tools_by_scope(state.registry.list_public_tools(), scopes)
238}
239
240fn visible_skill_names(
241 state: &ProxyState,
242 scopes: &ScopeConfig,
243) -> std::collections::HashSet<String> {
244 skill::visible_skills(&state.skill_registry, &state.registry, scopes)
245 .into_iter()
246 .map(|skill| skill.name.clone())
247 .collect()
248}
249
250fn visible_remote_skill_names(
262 state: &ProxyState,
263 scopes: &ScopeConfig,
264 catalog: &[RemoteSkillMeta],
265) -> std::collections::HashSet<String> {
266 let mut visible: std::collections::HashSet<String> = std::collections::HashSet::new();
267 if catalog.is_empty() {
268 return visible;
269 }
270 if scopes.is_wildcard() {
271 for entry in catalog {
272 visible.insert(entry.name.clone());
273 }
274 return visible;
275 }
276
277 let allowed_tool_pairs: Vec<(String, String)> =
281 crate::core::scope::filter_tools_by_scope(state.registry.list_public_tools(), scopes)
282 .into_iter()
283 .map(|(p, t)| (p.name.clone(), t.name.clone()))
284 .collect();
285 let allowed_tool_names: std::collections::HashSet<&str> =
286 allowed_tool_pairs.iter().map(|(_, t)| t.as_str()).collect();
287 let allowed_provider_names: std::collections::HashSet<&str> =
288 allowed_tool_pairs.iter().map(|(p, _)| p.as_str()).collect();
289 let allowed_categories: std::collections::HashSet<String> = state
290 .registry
291 .list_providers()
292 .into_iter()
293 .filter(|p| allowed_provider_names.contains(p.name.as_str()))
294 .filter_map(|p| p.category.clone())
295 .collect();
296
297 for scope in &scopes.scopes {
299 if let Some(skill_name) = scope.strip_prefix("skill:") {
300 if catalog.iter().any(|e| e.name == skill_name) {
301 visible.insert(skill_name.to_string());
302 }
303 }
304 }
305
306 for entry in catalog {
310 if entry
311 .tools
312 .iter()
313 .any(|t| allowed_tool_names.contains(t.as_str()))
314 || entry
315 .providers
316 .iter()
317 .any(|p| allowed_provider_names.contains(p.as_str()))
318 || entry
319 .categories
320 .iter()
321 .any(|c| allowed_categories.contains(c))
322 {
323 visible.insert(entry.name.clone());
324 }
325 }
326
327 visible
328}
329
330async fn visible_skill_names_with_remote(
334 state: &ProxyState,
335 scopes: &ScopeConfig,
336 client: &SkillAtiClient,
337) -> Result<std::collections::HashSet<String>, SkillAtiError> {
338 let mut names = visible_skill_names(state, scopes);
339 let catalog = client.catalog().await?;
340 let remote = visible_remote_skill_names(state, scopes, &catalog);
341 names.extend(remote);
342 Ok(names)
343}
344
345async fn handle_call(
346 State(state): State<Arc<ProxyState>>,
347 req: HttpRequest<Body>,
348) -> impl IntoResponse {
349 let claims = req.extensions().get::<TokenClaims>().cloned();
351 let bearer_token: String = req
355 .extensions()
356 .get::<BearerToken>()
357 .map(|b| b.0.clone())
358 .unwrap_or_default();
359
360 let body_bytes = match axum::body::to_bytes(req.into_body(), max_call_body_bytes()).await {
367 Ok(b) => b,
368 Err(e) => {
369 return (
370 StatusCode::BAD_REQUEST,
371 Json(CallResponse {
372 result: Value::Null,
373 error: Some(format!("Failed to read request body: {e}")),
374 }),
375 );
376 }
377 };
378
379 let call_req: CallRequest = match serde_json::from_slice(&body_bytes) {
380 Ok(r) => r,
381 Err(e) => {
382 return (
383 StatusCode::UNPROCESSABLE_ENTITY,
384 Json(CallResponse {
385 result: Value::Null,
386 error: Some(format!("Invalid request: {e}")),
387 }),
388 );
389 }
390 };
391
392 tracing::debug!(
393 tool = %call_req.tool_name,
394 args = ?call_req.args,
395 "POST /call"
396 );
397
398 let (provider, tool) = match state.registry.get_tool(&call_req.tool_name) {
401 Some(pt) => pt,
402 None => {
403 let mut resolved = None;
407 for (idx, _) in call_req.tool_name.match_indices('_') {
408 let candidate = format!(
409 "{}:{}",
410 &call_req.tool_name[..idx],
411 &call_req.tool_name[idx + 1..]
412 );
413 if let Some(pt) = state.registry.get_tool(&candidate) {
414 tracing::debug!(
415 original = %call_req.tool_name,
416 resolved = %candidate,
417 "resolved underscore tool name to colon format"
418 );
419 resolved = Some(pt);
420 break;
421 }
422 }
423
424 match resolved {
425 Some(pt) => pt,
426 None => {
427 return (
428 StatusCode::NOT_FOUND,
429 Json(CallResponse {
430 result: Value::Null,
431 error: Some(format!("Unknown tool: '{}'", call_req.tool_name)),
432 }),
433 );
434 }
435 }
436 }
437 };
438
439 if let Some(tool_scope) = &tool.scope {
441 let scopes = match &claims {
442 Some(c) => ScopeConfig::from_jwt(c),
443 None if state.jwt_config.is_none() => ScopeConfig::unrestricted(), None => {
445 return (
446 StatusCode::FORBIDDEN,
447 Json(CallResponse {
448 result: Value::Null,
449 error: Some("Authentication required — no JWT provided".into()),
450 }),
451 );
452 }
453 };
454
455 if !scopes.is_allowed(tool_scope) {
456 return (
457 StatusCode::FORBIDDEN,
458 Json(CallResponse {
459 result: Value::Null,
460 error: Some(format!(
461 "Access denied: '{}' is not in your scopes",
462 tool.name
463 )),
464 }),
465 );
466 }
467 }
468
469 {
471 let scopes = match &claims {
472 Some(c) => ScopeConfig::from_jwt(c),
473 None => ScopeConfig::unrestricted(),
474 };
475 if let Some(ref rate_config) = scopes.rate_config {
476 if let Err(e) = crate::core::rate::check_and_record(&call_req.tool_name, rate_config) {
477 return (
478 StatusCode::TOO_MANY_REQUESTS,
479 Json(CallResponse {
480 result: Value::Null,
481 error: Some(format!("{e}")),
482 }),
483 );
484 }
485 }
486 }
487
488 let gen_ctx = GenContext {
490 jwt_sub: claims
491 .as_ref()
492 .map(|c| c.sub.clone())
493 .unwrap_or_else(|| "dev".into()),
494 jwt_scope: claims
495 .as_ref()
496 .map(|c| c.scope.clone())
497 .unwrap_or_else(|| "*".into()),
498 tool_name: call_req.tool_name.clone(),
499 timestamp: crate::core::jwt::now_secs(),
500 jwt_token: bearer_token.clone(),
501 };
502
503 let agent_sub = claims.as_ref().map(|c| c.sub.clone()).unwrap_or_default();
505 let job_id = claims
506 .as_ref()
507 .and_then(|c| c.job_id.clone())
508 .unwrap_or_default();
509 let sandbox_id = claims
510 .as_ref()
511 .and_then(|c| c.sandbox_id.clone())
512 .unwrap_or_default();
513 tracing::info!(
514 tool = %call_req.tool_name,
515 agent = %agent_sub,
516 job_id = %job_id,
517 sandbox_id = %sandbox_id,
518 "tool call"
519 );
520 let start = std::time::Instant::now();
521
522 let response = match provider.handler.as_str() {
523 "mcp" => {
524 let args_map = call_req.args_as_map();
525 match mcp_client::execute_with_gen(
526 provider,
527 &call_req.tool_name,
528 &args_map,
529 &state.keyring,
530 Some(&gen_ctx),
531 Some(&state.auth_cache),
532 )
533 .await
534 {
535 Ok(result) => (
536 StatusCode::OK,
537 Json(CallResponse {
538 result,
539 error: None,
540 }),
541 ),
542 Err(e) => {
543 let (provider_name, operation_id) =
544 sentry_scope::split_tool_name(&call_req.tool_name);
545 sentry_scope::report_upstream_error(
546 &provider_name,
547 &operation_id,
548 0,
549 502,
550 None,
551 Some(&e.to_string()),
552 );
553 (
554 StatusCode::BAD_GATEWAY,
555 Json(CallResponse {
556 result: Value::Null,
557 error: Some(format!("MCP error: {e}")),
558 }),
559 )
560 }
561 }
562 }
563 "cli" => {
564 let positional = call_req.args_as_positional();
565 match crate::core::cli_executor::execute_with_gen(
566 provider,
567 &positional,
568 &state.keyring,
569 Some(&gen_ctx),
570 Some(&state.auth_cache),
571 )
572 .await
573 {
574 Ok(result) => (
575 StatusCode::OK,
576 Json(CallResponse {
577 result,
578 error: None,
579 }),
580 ),
581 Err(e) => {
582 let (provider_name, operation_id) =
583 sentry_scope::split_tool_name(&call_req.tool_name);
584 sentry_scope::report_upstream_error(
585 &provider_name,
586 &operation_id,
587 0,
588 502,
589 None,
590 Some(&e.to_string()),
591 );
592 (
593 StatusCode::BAD_GATEWAY,
594 Json(CallResponse {
595 result: Value::Null,
596 error: Some(format!("CLI error: {e}")),
597 }),
598 )
599 }
600 }
601 }
602 "file_manager" => {
603 let args_map = call_req.args_as_map();
604 match dispatch_file_manager(&call_req.tool_name, &args_map, provider, &state.keyring)
605 .await
606 {
607 Ok(result) => (
608 StatusCode::OK,
609 Json(CallResponse {
610 result,
611 error: None,
612 }),
613 ),
614 Err((status, msg)) => (
615 status,
616 Json(CallResponse {
617 result: Value::Null,
618 error: Some(msg),
619 }),
620 ),
621 }
622 }
623 _ => {
624 let args_map = call_req.args_as_map();
625 let raw_response = match http::execute_tool_with_gen(
626 provider,
627 tool,
628 &args_map,
629 &state.keyring,
630 Some(&gen_ctx),
631 Some(&state.auth_cache),
632 )
633 .await
634 {
635 Ok(resp) => resp,
636 Err(http::HttpError::NoRecordsFound { status }) => {
637 let duration = start.elapsed();
641 tracing::info!(
642 tool = %call_req.tool_name,
643 upstream_status = status,
644 "upstream returned no records"
645 );
646 write_proxy_audit(&call_req, &agent_sub, claims.as_ref(), duration, None);
647 return (
648 StatusCode::OK,
649 Json(CallResponse {
650 result: serde_json::json!({ "records": [] }),
651 error: None,
652 }),
653 );
654 }
655 Err(e) => {
656 let duration = start.elapsed();
657 let (provider_name, operation_id) =
658 sentry_scope::split_tool_name(&call_req.tool_name);
659 let (upstream_status, error_type, error_message) = match &e {
660 http::HttpError::ApiError {
661 status,
662 error_type,
663 error_message,
664 ..
665 } => (*status, error_type.clone(), error_message.clone()),
666 _ => (0u16, None, Some(e.to_string())),
667 };
668 sentry_scope::report_upstream_error(
669 &provider_name,
670 &operation_id,
671 upstream_status,
672 502,
673 error_type.as_deref(),
674 error_message.as_deref(),
675 );
676 write_proxy_audit(
677 &call_req,
678 &agent_sub,
679 claims.as_ref(),
680 duration,
681 Some(&e.to_string()),
682 );
683 return (
684 StatusCode::BAD_GATEWAY,
685 Json(CallResponse {
686 result: Value::Null,
687 error: Some(format!("Upstream API error: {e}")),
688 }),
689 );
690 }
691 };
692
693 let processed = match response::process_response(&raw_response, tool.response.as_ref())
694 {
695 Ok(p) => p,
696 Err(e) => {
697 let duration = start.elapsed();
698 write_proxy_audit(
699 &call_req,
700 &agent_sub,
701 claims.as_ref(),
702 duration,
703 Some(&e.to_string()),
704 );
705 return (
706 StatusCode::INTERNAL_SERVER_ERROR,
707 Json(CallResponse {
708 result: raw_response,
709 error: Some(format!("Response processing error: {e}")),
710 }),
711 );
712 }
713 };
714
715 (
716 StatusCode::OK,
717 Json(CallResponse {
718 result: processed,
719 error: None,
720 }),
721 )
722 }
723 };
724
725 let duration = start.elapsed();
726 let error_msg = response.1.error.as_deref();
727 write_proxy_audit(&call_req, &agent_sub, claims.as_ref(), duration, error_msg);
728
729 response
730}
731
732async fn handle_help(
733 State(state): State<Arc<ProxyState>>,
734 claims: Option<Extension<TokenClaims>>,
735 Json(req): Json<HelpRequest>,
736) -> impl IntoResponse {
737 tracing::debug!(query = %req.query, tool = ?req.tool, "POST /help");
738
739 let claims = claims.map(|Extension(claims)| claims);
740 let scopes = scopes_for_request(claims.as_ref(), &state);
741
742 let (llm_provider, llm_tool) = match state.registry.get_tool("_chat_completion") {
743 Some(pt) => pt,
744 None => {
745 return (
746 StatusCode::SERVICE_UNAVAILABLE,
747 Json(HelpResponse {
748 content: String::new(),
749 error: Some("No _llm.toml manifest found. Proxy help requires a configured LLM provider.".into()),
750 }),
751 );
752 }
753 };
754
755 let api_key = match llm_provider
756 .auth_key_name
757 .as_deref()
758 .and_then(|k| state.keyring.get(k))
759 {
760 Some(key) => key.to_string(),
761 None => {
762 return (
763 StatusCode::SERVICE_UNAVAILABLE,
764 Json(HelpResponse {
765 content: String::new(),
766 error: Some("LLM API key not found in keyring".into()),
767 }),
768 );
769 }
770 };
771
772 let resolved_skills = skill::resolve_skills(&state.skill_registry, &state.registry, &scopes);
773 let local_skills_section = if resolved_skills.is_empty() {
774 String::new()
775 } else {
776 format!(
777 "## Available Skills (methodology guides)\n{}",
778 skill::build_skill_context(&resolved_skills)
779 )
780 };
781 let remote_query = req
782 .tool
783 .as_ref()
784 .map(|tool| format!("{tool} {}", req.query))
785 .unwrap_or_else(|| req.query.clone());
786 let remote_skills_section =
787 build_remote_skillati_section(&state.keyring, &remote_query, 12).await;
788 let skills_section = merge_help_skill_sections(&[local_skills_section, remote_skills_section]);
789
790 let visible_tools = visible_tools_for_scopes(&state, &scopes);
792 let system_prompt = if let Some(ref tool_name) = req.tool {
793 match build_scoped_prompt(tool_name, &visible_tools, &skills_section) {
795 Some(prompt) => prompt,
796 None => {
797 return (
798 StatusCode::FORBIDDEN,
799 Json(HelpResponse {
800 content: String::new(),
801 error: Some(format!(
802 "Scope '{tool_name}' is not visible in your current scopes."
803 )),
804 }),
805 );
806 }
807 }
808 } else {
809 let tools_context = build_tool_context(&visible_tools);
810 HELP_SYSTEM_PROMPT
811 .replace("{tools}", &tools_context)
812 .replace("{skills_section}", &skills_section)
813 };
814
815 let request_body = serde_json::json!({
816 "model": "zai-glm-4.7",
817 "messages": [
818 {"role": "system", "content": system_prompt},
819 {"role": "user", "content": req.query}
820 ],
821 "max_completion_tokens": 1536,
822 "temperature": 0.3
823 });
824
825 let client = reqwest::Client::new();
826 let url = format!(
827 "{}{}",
828 llm_provider.base_url.trim_end_matches('/'),
829 llm_tool.endpoint
830 );
831
832 let response = match client
833 .post(&url)
834 .bearer_auth(&api_key)
835 .json(&request_body)
836 .send()
837 .await
838 {
839 Ok(r) => r,
840 Err(e) => {
841 return (
842 StatusCode::BAD_GATEWAY,
843 Json(HelpResponse {
844 content: String::new(),
845 error: Some(format!("LLM request failed: {e}")),
846 }),
847 );
848 }
849 };
850
851 if !response.status().is_success() {
852 let status = response.status();
853 let body = response.text().await.unwrap_or_default();
854 return (
855 StatusCode::BAD_GATEWAY,
856 Json(HelpResponse {
857 content: String::new(),
858 error: Some(format!("LLM API error ({status}): {body}")),
859 }),
860 );
861 }
862
863 let body: Value = match response.json().await {
864 Ok(b) => b,
865 Err(e) => {
866 return (
867 StatusCode::INTERNAL_SERVER_ERROR,
868 Json(HelpResponse {
869 content: String::new(),
870 error: Some(format!("Failed to parse LLM response: {e}")),
871 }),
872 );
873 }
874 };
875
876 let content = body
877 .pointer("/choices/0/message/content")
878 .and_then(|c| c.as_str())
879 .unwrap_or("No response from LLM")
880 .to_string();
881
882 (
883 StatusCode::OK,
884 Json(HelpResponse {
885 content,
886 error: None,
887 }),
888 )
889}
890
891async fn handle_health(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
892 let auth = if state.jwt_config.is_some() {
893 "jwt"
894 } else {
895 "disabled"
896 };
897
898 Json(HealthResponse {
899 status: "ok".into(),
900 version: env!("CARGO_PKG_VERSION").into(),
901 tools: state.registry.list_public_tools().len(),
902 providers: state.registry.list_providers().len(),
903 skills: state.skill_registry.skill_count(),
904 auth: auth.into(),
905 })
906}
907
908async fn handle_jwks(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
910 match &state.jwks_json {
911 Some(jwks) => (StatusCode::OK, Json(jwks.clone())),
912 None => (
913 StatusCode::NOT_FOUND,
914 Json(serde_json::json!({"error": "JWKS not configured"})),
915 ),
916 }
917}
918
919async fn handle_mcp(
924 State(state): State<Arc<ProxyState>>,
925 claims: Option<Extension<TokenClaims>>,
926 bearer: Option<Extension<BearerToken>>,
927 Json(msg): Json<Value>,
928) -> impl IntoResponse {
929 let claims = claims.map(|Extension(claims)| claims);
930 let bearer_token: String = bearer.map(|Extension(b)| b.0).unwrap_or_default();
934 let scopes = scopes_for_request(claims.as_ref(), &state);
935 let method = msg.get("method").and_then(|m| m.as_str()).unwrap_or("");
936 let id = msg.get("id").cloned();
937 tracing::info!(
938 %method,
939 agent = claims.as_ref().map(|c| c.sub.as_str()).unwrap_or(""),
940 job_id = claims.as_ref().and_then(|c| c.job_id.as_deref()).unwrap_or(""),
941 sandbox_id = claims.as_ref().and_then(|c| c.sandbox_id.as_deref()).unwrap_or(""),
942 "mcp call"
943 );
944
945 match method {
946 "initialize" => {
947 let result = serde_json::json!({
948 "protocolVersion": "2025-03-26",
949 "capabilities": {
950 "tools": { "listChanged": false }
951 },
952 "serverInfo": {
953 "name": "ati-proxy",
954 "version": env!("CARGO_PKG_VERSION")
955 }
956 });
957 jsonrpc_success(id, result)
958 }
959
960 "notifications/initialized" => (StatusCode::ACCEPTED, Json(Value::Null)),
961
962 "tools/list" => {
963 let visible_tools = visible_tools_for_scopes(&state, &scopes);
964 let mcp_tools: Vec<Value> = visible_tools
965 .iter()
966 .map(|(_provider, tool)| {
967 serde_json::json!({
968 "name": tool.name,
969 "description": tool.description,
970 "inputSchema": tool.input_schema.clone().unwrap_or(serde_json::json!({
971 "type": "object",
972 "properties": {}
973 }))
974 })
975 })
976 .collect();
977
978 let result = serde_json::json!({
979 "tools": mcp_tools,
980 });
981 jsonrpc_success(id, result)
982 }
983
984 "tools/call" => {
985 let params = msg.get("params").cloned().unwrap_or(Value::Null);
986 let tool_name = params.get("name").and_then(|n| n.as_str()).unwrap_or("");
987 let arguments: HashMap<String, Value> = params
988 .get("arguments")
989 .and_then(|a| serde_json::from_value(a.clone()).ok())
990 .unwrap_or_default();
991
992 if tool_name.is_empty() {
993 return jsonrpc_error(id, -32602, "Missing tool name in params.name");
994 }
995
996 let (provider, _tool) = match state.registry.get_tool(tool_name) {
997 Some(pt) => pt,
998 None => {
999 return jsonrpc_error(id, -32602, &format!("Unknown tool: '{tool_name}'"));
1000 }
1001 };
1002
1003 if let Some(tool_scope) = &_tool.scope {
1004 if !scopes.is_allowed(tool_scope) {
1005 return jsonrpc_error(
1006 id,
1007 -32001,
1008 &format!("Access denied: '{}' is not in your scopes", _tool.name),
1009 );
1010 }
1011 }
1012
1013 tracing::debug!(%tool_name, provider = %provider.name, "MCP tools/call");
1014
1015 let mcp_gen_ctx = GenContext {
1016 jwt_sub: claims
1017 .as_ref()
1018 .map(|claims| claims.sub.clone())
1019 .unwrap_or_else(|| "dev".into()),
1020 jwt_scope: claims
1021 .as_ref()
1022 .map(|claims| claims.scope.clone())
1023 .unwrap_or_else(|| "*".into()),
1024 tool_name: tool_name.to_string(),
1025 timestamp: crate::core::jwt::now_secs(),
1026 jwt_token: bearer_token.clone(),
1027 };
1028
1029 let result = if provider.is_mcp() {
1030 mcp_client::execute_with_gen(
1031 provider,
1032 tool_name,
1033 &arguments,
1034 &state.keyring,
1035 Some(&mcp_gen_ctx),
1036 Some(&state.auth_cache),
1037 )
1038 .await
1039 } else if provider.is_cli() {
1040 let raw: Vec<String> = arguments
1042 .iter()
1043 .flat_map(|(k, v)| {
1044 let val = match v {
1045 Value::String(s) => s.clone(),
1046 other => other.to_string(),
1047 };
1048 vec![format!("--{k}"), val]
1049 })
1050 .collect();
1051 crate::core::cli_executor::execute_with_gen(
1052 provider,
1053 &raw,
1054 &state.keyring,
1055 Some(&mcp_gen_ctx),
1056 Some(&state.auth_cache),
1057 )
1058 .await
1059 .map_err(|e| mcp_client::McpError::Transport(e.to_string()))
1060 } else {
1061 match http::execute_tool_with_gen(
1062 provider,
1063 _tool,
1064 &arguments,
1065 &state.keyring,
1066 Some(&mcp_gen_ctx),
1067 Some(&state.auth_cache),
1068 )
1069 .await
1070 {
1071 Ok(val) => Ok(val),
1072 Err(e) => Err(mcp_client::McpError::Transport(e.to_string())),
1073 }
1074 };
1075
1076 match result {
1077 Ok(value) => {
1078 let text = match &value {
1079 Value::String(s) => s.clone(),
1080 other => serde_json::to_string_pretty(other).unwrap_or_default(),
1081 };
1082 let mcp_result = serde_json::json!({
1083 "content": [{"type": "text", "text": text}],
1084 "isError": false,
1085 });
1086 jsonrpc_success(id, mcp_result)
1087 }
1088 Err(e) => {
1089 let mcp_result = serde_json::json!({
1090 "content": [{"type": "text", "text": format!("Error: {e}")}],
1091 "isError": true,
1092 });
1093 jsonrpc_success(id, mcp_result)
1094 }
1095 }
1096 }
1097
1098 _ => jsonrpc_error(id, -32601, &format!("Method not found: '{method}'")),
1099 }
1100}
1101
1102fn jsonrpc_success(id: Option<Value>, result: Value) -> (StatusCode, Json<Value>) {
1103 (
1104 StatusCode::OK,
1105 Json(serde_json::json!({
1106 "jsonrpc": "2.0",
1107 "id": id,
1108 "result": result,
1109 })),
1110 )
1111}
1112
1113fn jsonrpc_error(id: Option<Value>, code: i64, message: &str) -> (StatusCode, Json<Value>) {
1114 (
1115 StatusCode::OK,
1116 Json(serde_json::json!({
1117 "jsonrpc": "2.0",
1118 "id": id,
1119 "error": {
1120 "code": code,
1121 "message": message,
1122 }
1123 })),
1124 )
1125}
1126
1127async fn handle_tools_list(
1133 State(state): State<Arc<ProxyState>>,
1134 claims: Option<Extension<TokenClaims>>,
1135 axum::extract::Query(query): axum::extract::Query<ToolsQuery>,
1136) -> impl IntoResponse {
1137 tracing::debug!(
1138 provider = ?query.provider,
1139 search = ?query.search,
1140 "GET /tools"
1141 );
1142
1143 let claims = claims.map(|Extension(claims)| claims);
1144 let scopes = scopes_for_request(claims.as_ref(), &state);
1145 let all_tools = visible_tools_for_scopes(&state, &scopes);
1146
1147 let tools: Vec<Value> = all_tools
1148 .iter()
1149 .filter(|(provider, tool)| {
1150 if let Some(ref p) = query.provider {
1151 if provider.name != *p {
1152 return false;
1153 }
1154 }
1155 if let Some(ref q) = query.search {
1156 let q = q.to_lowercase();
1157 let name_match = tool.name.to_lowercase().contains(&q);
1158 let desc_match = tool.description.to_lowercase().contains(&q);
1159 let tag_match = tool.tags.iter().any(|t| t.to_lowercase().contains(&q));
1160 if !name_match && !desc_match && !tag_match {
1161 return false;
1162 }
1163 }
1164 true
1165 })
1166 .map(|(provider, tool)| {
1167 serde_json::json!({
1168 "name": tool.name,
1169 "description": tool.description,
1170 "provider": provider.name,
1171 "method": format!("{:?}", tool.method),
1172 "tags": tool.tags,
1173 "skills": provider.skills,
1174 "input_schema": tool.input_schema,
1175 })
1176 })
1177 .collect();
1178
1179 (StatusCode::OK, Json(Value::Array(tools)))
1180}
1181
1182async fn handle_tool_info(
1184 State(state): State<Arc<ProxyState>>,
1185 claims: Option<Extension<TokenClaims>>,
1186 axum::extract::Path(name): axum::extract::Path<String>,
1187) -> impl IntoResponse {
1188 tracing::debug!(tool = %name, "GET /tools/:name");
1189
1190 let claims = claims.map(|Extension(claims)| claims);
1191 let scopes = scopes_for_request(claims.as_ref(), &state);
1192
1193 match state
1194 .registry
1195 .get_tool(&name)
1196 .filter(|(_, tool)| match &tool.scope {
1197 Some(scope) => scopes.is_allowed(scope),
1198 None => true,
1199 }) {
1200 Some((provider, tool)) => {
1201 let mut skills: Vec<String> = provider.skills.clone();
1203 for s in state.skill_registry.skills_for_tool(&tool.name) {
1204 if !skills.contains(&s.name) {
1205 skills.push(s.name.clone());
1206 }
1207 }
1208 for s in state.skill_registry.skills_for_provider(&provider.name) {
1209 if !skills.contains(&s.name) {
1210 skills.push(s.name.clone());
1211 }
1212 }
1213
1214 (
1215 StatusCode::OK,
1216 Json(serde_json::json!({
1217 "name": tool.name,
1218 "description": tool.description,
1219 "provider": provider.name,
1220 "method": format!("{:?}", tool.method),
1221 "endpoint": tool.endpoint,
1222 "tags": tool.tags,
1223 "hint": tool.hint,
1224 "skills": skills,
1225 "input_schema": tool.input_schema,
1226 "scope": tool.scope,
1227 })),
1228 )
1229 }
1230 None => (
1231 StatusCode::NOT_FOUND,
1232 Json(serde_json::json!({"error": format!("Tool '{name}' not found")})),
1233 ),
1234 }
1235}
1236
1237async fn handle_skills_list(
1242 State(state): State<Arc<ProxyState>>,
1243 claims: Option<Extension<TokenClaims>>,
1244 axum::extract::Query(query): axum::extract::Query<SkillsQuery>,
1245) -> impl IntoResponse {
1246 tracing::debug!(
1247 category = ?query.category,
1248 provider = ?query.provider,
1249 tool = ?query.tool,
1250 search = ?query.search,
1251 "GET /skills"
1252 );
1253
1254 let claims = claims.map(|Extension(claims)| claims);
1255 let scopes = scopes_for_request(claims.as_ref(), &state);
1256 let visible_names = visible_skill_names(&state, &scopes);
1257
1258 let skills: Vec<&skill::SkillMeta> = if let Some(search_query) = &query.search {
1259 state
1260 .skill_registry
1261 .search(search_query)
1262 .into_iter()
1263 .filter(|skill| visible_names.contains(&skill.name))
1264 .collect()
1265 } else if let Some(cat) = &query.category {
1266 state
1267 .skill_registry
1268 .skills_for_category(cat)
1269 .into_iter()
1270 .filter(|skill| visible_names.contains(&skill.name))
1271 .collect()
1272 } else if let Some(prov) = &query.provider {
1273 state
1274 .skill_registry
1275 .skills_for_provider(prov)
1276 .into_iter()
1277 .filter(|skill| visible_names.contains(&skill.name))
1278 .collect()
1279 } else if let Some(t) = &query.tool {
1280 state
1281 .skill_registry
1282 .skills_for_tool(t)
1283 .into_iter()
1284 .filter(|skill| visible_names.contains(&skill.name))
1285 .collect()
1286 } else {
1287 state
1288 .skill_registry
1289 .list_skills()
1290 .iter()
1291 .filter(|skill| visible_names.contains(&skill.name))
1292 .collect()
1293 };
1294
1295 let json: Vec<Value> = skills
1296 .iter()
1297 .map(|s| {
1298 serde_json::json!({
1299 "name": s.name,
1300 "version": s.version,
1301 "description": s.description,
1302 "tools": s.tools,
1303 "providers": s.providers,
1304 "categories": s.categories,
1305 "hint": s.hint,
1306 })
1307 })
1308 .collect();
1309
1310 (StatusCode::OK, Json(Value::Array(json)))
1311}
1312
1313async fn handle_skill_detail(
1314 State(state): State<Arc<ProxyState>>,
1315 claims: Option<Extension<TokenClaims>>,
1316 axum::extract::Path(name): axum::extract::Path<String>,
1317 axum::extract::Query(query): axum::extract::Query<SkillDetailQuery>,
1318) -> impl IntoResponse {
1319 tracing::debug!(%name, meta = ?query.meta, refs = ?query.refs, "GET /skills/:name");
1320
1321 let claims = claims.map(|Extension(claims)| claims);
1322 let scopes = scopes_for_request(claims.as_ref(), &state);
1323 let visible_names = visible_skill_names(&state, &scopes);
1324
1325 let skill_meta = match state
1326 .skill_registry
1327 .get_skill(&name)
1328 .filter(|skill| visible_names.contains(&skill.name))
1329 {
1330 Some(s) => s,
1331 None => {
1332 return (
1333 StatusCode::NOT_FOUND,
1334 Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
1335 );
1336 }
1337 };
1338
1339 if query.meta.unwrap_or(false) {
1340 return (
1341 StatusCode::OK,
1342 Json(serde_json::json!({
1343 "name": skill_meta.name,
1344 "version": skill_meta.version,
1345 "description": skill_meta.description,
1346 "author": skill_meta.author,
1347 "tools": skill_meta.tools,
1348 "providers": skill_meta.providers,
1349 "categories": skill_meta.categories,
1350 "keywords": skill_meta.keywords,
1351 "hint": skill_meta.hint,
1352 "depends_on": skill_meta.depends_on,
1353 "suggests": skill_meta.suggests,
1354 "license": skill_meta.license,
1355 "compatibility": skill_meta.compatibility,
1356 "allowed_tools": skill_meta.allowed_tools,
1357 "format": skill_meta.format,
1358 })),
1359 );
1360 }
1361
1362 let content = match state.skill_registry.read_content(&name) {
1363 Ok(c) => c,
1364 Err(e) => {
1365 return (
1366 StatusCode::INTERNAL_SERVER_ERROR,
1367 Json(serde_json::json!({"error": format!("Failed to read skill: {e}")})),
1368 );
1369 }
1370 };
1371
1372 let mut response = serde_json::json!({
1373 "name": skill_meta.name,
1374 "version": skill_meta.version,
1375 "description": skill_meta.description,
1376 "content": content,
1377 });
1378
1379 if query.refs.unwrap_or(false) {
1380 if let Ok(refs) = state.skill_registry.list_references(&name) {
1381 response["references"] = serde_json::json!(refs);
1382 }
1383 }
1384
1385 (StatusCode::OK, Json(response))
1386}
1387
1388async fn handle_skill_bundle(
1392 State(state): State<Arc<ProxyState>>,
1393 claims: Option<Extension<TokenClaims>>,
1394 axum::extract::Path(name): axum::extract::Path<String>,
1395) -> impl IntoResponse {
1396 tracing::debug!(skill = %name, "GET /skills/:name/bundle");
1397
1398 let claims = claims.map(|Extension(claims)| claims);
1399 let scopes = scopes_for_request(claims.as_ref(), &state);
1400 let visible_names = visible_skill_names(&state, &scopes);
1401 if !visible_names.contains(&name) {
1402 return (
1403 StatusCode::NOT_FOUND,
1404 Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
1405 );
1406 }
1407
1408 let files = match state.skill_registry.bundle_files(&name) {
1409 Ok(f) => f,
1410 Err(_) => {
1411 return (
1412 StatusCode::NOT_FOUND,
1413 Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
1414 );
1415 }
1416 };
1417
1418 let mut file_map = serde_json::Map::new();
1420 for (path, data) in &files {
1421 match std::str::from_utf8(data) {
1422 Ok(text) => {
1423 file_map.insert(path.clone(), Value::String(text.to_string()));
1424 }
1425 Err(_) => {
1426 use base64::Engine;
1428 let encoded = base64::engine::general_purpose::STANDARD.encode(data);
1429 file_map.insert(path.clone(), serde_json::json!({"base64": encoded}));
1430 }
1431 }
1432 }
1433
1434 (
1435 StatusCode::OK,
1436 Json(serde_json::json!({
1437 "name": name,
1438 "files": file_map,
1439 })),
1440 )
1441}
1442
1443async fn handle_skills_bundle_batch(
1447 State(state): State<Arc<ProxyState>>,
1448 claims: Option<Extension<TokenClaims>>,
1449 Json(req): Json<SkillBundleBatchRequest>,
1450) -> impl IntoResponse {
1451 const MAX_BATCH: usize = 50;
1452 if req.names.len() > MAX_BATCH {
1453 return (
1454 StatusCode::BAD_REQUEST,
1455 Json(
1456 serde_json::json!({"error": format!("batch size {} exceeds limit of {MAX_BATCH}", req.names.len())}),
1457 ),
1458 );
1459 }
1460
1461 tracing::debug!(names = ?req.names, "POST /skills/bundle");
1462
1463 let claims = claims.map(|Extension(claims)| claims);
1464 let scopes = scopes_for_request(claims.as_ref(), &state);
1465 let visible_names = visible_skill_names(&state, &scopes);
1466
1467 let mut result = serde_json::Map::new();
1468 let mut missing: Vec<String> = Vec::new();
1469
1470 for name in &req.names {
1471 if !visible_names.contains(name) {
1472 missing.push(name.clone());
1473 continue;
1474 }
1475 let files = match state.skill_registry.bundle_files(name) {
1476 Ok(f) => f,
1477 Err(_) => {
1478 missing.push(name.clone());
1479 continue;
1480 }
1481 };
1482
1483 let mut file_map = serde_json::Map::new();
1484 for (path, data) in &files {
1485 match std::str::from_utf8(data) {
1486 Ok(text) => {
1487 file_map.insert(path.clone(), Value::String(text.to_string()));
1488 }
1489 Err(_) => {
1490 use base64::Engine;
1491 let encoded = base64::engine::general_purpose::STANDARD.encode(data);
1492 file_map.insert(path.clone(), serde_json::json!({"base64": encoded}));
1493 }
1494 }
1495 }
1496
1497 result.insert(name.clone(), serde_json::json!({ "files": file_map }));
1498 }
1499
1500 (
1501 StatusCode::OK,
1502 Json(serde_json::json!({ "skills": result, "missing": missing })),
1503 )
1504}
1505
1506async fn handle_skills_resolve(
1507 State(state): State<Arc<ProxyState>>,
1508 claims: Option<Extension<TokenClaims>>,
1509 Json(req): Json<SkillResolveRequest>,
1510) -> impl IntoResponse {
1511 tracing::debug!(scopes = ?req.scopes, include_content = req.include_content, "POST /skills/resolve");
1512
1513 let include_content = req.include_content;
1514 let request_scopes = ScopeConfig {
1515 scopes: req.scopes,
1516 sub: String::new(),
1517 expires_at: 0,
1518 rate_config: None,
1519 };
1520 let claims = claims.map(|Extension(claims)| claims);
1521 let caller_scopes = scopes_for_request(claims.as_ref(), &state);
1522 let visible_names = visible_skill_names(&state, &caller_scopes);
1523
1524 let resolved: Vec<&skill::SkillMeta> =
1525 skill::resolve_skills(&state.skill_registry, &state.registry, &request_scopes)
1526 .into_iter()
1527 .filter(|skill| visible_names.contains(&skill.name))
1528 .collect();
1529
1530 let json: Vec<Value> = resolved
1531 .iter()
1532 .map(|s| {
1533 let mut entry = serde_json::json!({
1534 "name": s.name,
1535 "version": s.version,
1536 "description": s.description,
1537 "tools": s.tools,
1538 "providers": s.providers,
1539 "categories": s.categories,
1540 });
1541 if include_content {
1542 if let Ok(content) = state.skill_registry.read_content(&s.name) {
1543 entry["content"] = Value::String(content);
1544 }
1545 }
1546 entry
1547 })
1548 .collect();
1549
1550 (StatusCode::OK, Json(Value::Array(json)))
1551}
1552
1553fn skillati_client(keyring: &Keyring) -> Result<SkillAtiClient, SkillAtiError> {
1554 match SkillAtiClient::from_env(keyring)? {
1555 Some(client) => Ok(client),
1556 None => Err(SkillAtiError::NotConfigured),
1557 }
1558}
1559
1560async fn handle_skillati_catalog(
1561 State(state): State<Arc<ProxyState>>,
1562 claims: Option<Extension<TokenClaims>>,
1563 Query(query): Query<SkillAtiCatalogQuery>,
1564) -> impl IntoResponse {
1565 tracing::debug!(search = ?query.search, "GET /skillati/catalog");
1566
1567 let client = match skillati_client(&state.keyring) {
1568 Ok(client) => client,
1569 Err(err) => return skillati_error_response(err),
1570 };
1571
1572 let claims = claims.map(|Extension(c)| c);
1573 let scopes = scopes_for_request(claims.as_ref(), &state);
1574
1575 match client.catalog().await {
1576 Ok(catalog) => {
1577 let mut visible_names = visible_skill_names(&state, &scopes);
1581 visible_names.extend(visible_remote_skill_names(&state, &scopes, &catalog));
1582
1583 let mut skills: Vec<_> = catalog
1584 .into_iter()
1585 .filter(|s| visible_names.contains(&s.name))
1586 .collect();
1587 if let Some(search) = query.search.as_deref() {
1588 skills = SkillAtiClient::filter_catalog(&skills, search, 25);
1589 }
1590 (
1591 StatusCode::OK,
1592 Json(serde_json::json!({
1593 "skills": skills,
1594 })),
1595 )
1596 }
1597 Err(err) => skillati_error_response(err),
1598 }
1599}
1600
1601async fn handle_skillati_read(
1602 State(state): State<Arc<ProxyState>>,
1603 claims: Option<Extension<TokenClaims>>,
1604 axum::extract::Path(name): axum::extract::Path<String>,
1605) -> impl IntoResponse {
1606 tracing::debug!(%name, "GET /skillati/:name");
1607
1608 let client = match skillati_client(&state.keyring) {
1609 Ok(client) => client,
1610 Err(err) => return skillati_error_response(err),
1611 };
1612
1613 let claims = claims.map(|Extension(c)| c);
1614 let scopes = scopes_for_request(claims.as_ref(), &state);
1615 let visible_names = match visible_skill_names_with_remote(&state, &scopes, &client).await {
1616 Ok(v) => v,
1617 Err(err) => return skillati_error_response(err),
1618 };
1619 if !visible_names.contains(&name) {
1620 return skillati_error_response(SkillAtiError::SkillNotFound(name));
1621 }
1622
1623 match client.read_skill(&name).await {
1624 Ok(activation) => (StatusCode::OK, Json(serde_json::json!(activation))),
1625 Err(err) => skillati_error_response(err),
1626 }
1627}
1628
1629async fn handle_skillati_resources(
1630 State(state): State<Arc<ProxyState>>,
1631 claims: Option<Extension<TokenClaims>>,
1632 axum::extract::Path(name): axum::extract::Path<String>,
1633 Query(query): Query<SkillAtiResourcesQuery>,
1634) -> impl IntoResponse {
1635 tracing::debug!(%name, prefix = ?query.prefix, "GET /skillati/:name/resources");
1636
1637 let client = match skillati_client(&state.keyring) {
1638 Ok(client) => client,
1639 Err(err) => return skillati_error_response(err),
1640 };
1641
1642 let claims = claims.map(|Extension(c)| c);
1643 let scopes = scopes_for_request(claims.as_ref(), &state);
1644 let visible_names = match visible_skill_names_with_remote(&state, &scopes, &client).await {
1645 Ok(v) => v,
1646 Err(err) => return skillati_error_response(err),
1647 };
1648 if !visible_names.contains(&name) {
1649 return skillati_error_response(SkillAtiError::SkillNotFound(name));
1650 }
1651
1652 match client.list_resources(&name, query.prefix.as_deref()).await {
1653 Ok(resources) => (
1654 StatusCode::OK,
1655 Json(serde_json::json!({
1656 "name": name,
1657 "prefix": query.prefix,
1658 "resources": resources,
1659 })),
1660 ),
1661 Err(err) => skillati_error_response(err),
1662 }
1663}
1664
1665async fn handle_skillati_file(
1666 State(state): State<Arc<ProxyState>>,
1667 claims: Option<Extension<TokenClaims>>,
1668 axum::extract::Path(name): axum::extract::Path<String>,
1669 Query(query): Query<SkillAtiFileQuery>,
1670) -> impl IntoResponse {
1671 tracing::debug!(%name, path = %query.path, "GET /skillati/:name/file");
1672
1673 let client = match skillati_client(&state.keyring) {
1674 Ok(client) => client,
1675 Err(err) => return skillati_error_response(err),
1676 };
1677
1678 let claims = claims.map(|Extension(c)| c);
1679 let scopes = scopes_for_request(claims.as_ref(), &state);
1680 let visible_names = match visible_skill_names_with_remote(&state, &scopes, &client).await {
1681 Ok(v) => v,
1682 Err(err) => return skillati_error_response(err),
1683 };
1684 if !visible_names.contains(&name) {
1685 return skillati_error_response(SkillAtiError::SkillNotFound(name));
1686 }
1687
1688 match client.read_path(&name, &query.path).await {
1689 Ok(file) => (StatusCode::OK, Json(serde_json::json!(file))),
1690 Err(err) => skillati_error_response(err),
1691 }
1692}
1693
1694async fn handle_skillati_refs(
1695 State(state): State<Arc<ProxyState>>,
1696 claims: Option<Extension<TokenClaims>>,
1697 axum::extract::Path(name): axum::extract::Path<String>,
1698) -> impl IntoResponse {
1699 tracing::debug!(%name, "GET /skillati/:name/refs");
1700
1701 let client = match skillati_client(&state.keyring) {
1702 Ok(client) => client,
1703 Err(err) => return skillati_error_response(err),
1704 };
1705
1706 let claims = claims.map(|Extension(c)| c);
1707 let scopes = scopes_for_request(claims.as_ref(), &state);
1708 let visible_names = match visible_skill_names_with_remote(&state, &scopes, &client).await {
1709 Ok(v) => v,
1710 Err(err) => return skillati_error_response(err),
1711 };
1712 if !visible_names.contains(&name) {
1713 return skillati_error_response(SkillAtiError::SkillNotFound(name));
1714 }
1715
1716 match client.list_references(&name).await {
1717 Ok(references) => (
1718 StatusCode::OK,
1719 Json(serde_json::json!({
1720 "name": name,
1721 "references": references,
1722 })),
1723 ),
1724 Err(err) => skillati_error_response(err),
1725 }
1726}
1727
1728async fn handle_skillati_ref(
1729 State(state): State<Arc<ProxyState>>,
1730 claims: Option<Extension<TokenClaims>>,
1731 axum::extract::Path((name, reference)): axum::extract::Path<(String, String)>,
1732) -> impl IntoResponse {
1733 tracing::debug!(%name, %reference, "GET /skillati/:name/ref/:reference");
1734
1735 let client = match skillati_client(&state.keyring) {
1736 Ok(client) => client,
1737 Err(err) => return skillati_error_response(err),
1738 };
1739
1740 let claims = claims.map(|Extension(c)| c);
1741 let scopes = scopes_for_request(claims.as_ref(), &state);
1742 let visible_names = match visible_skill_names_with_remote(&state, &scopes, &client).await {
1743 Ok(v) => v,
1744 Err(err) => return skillati_error_response(err),
1745 };
1746 if !visible_names.contains(&name) {
1747 return skillati_error_response(SkillAtiError::SkillNotFound(name));
1748 }
1749
1750 match client.read_reference(&name, &reference).await {
1751 Ok(content) => (
1752 StatusCode::OK,
1753 Json(serde_json::json!({
1754 "name": name,
1755 "reference": reference,
1756 "content": content,
1757 })),
1758 ),
1759 Err(err) => skillati_error_response(err),
1760 }
1761}
1762
1763fn skillati_error_response(err: SkillAtiError) -> (StatusCode, Json<Value>) {
1764 let status = match &err {
1765 SkillAtiError::NotConfigured
1766 | SkillAtiError::UnsupportedRegistry(_)
1767 | SkillAtiError::MissingCredentials(_)
1768 | SkillAtiError::ProxyUrlRequired => StatusCode::SERVICE_UNAVAILABLE,
1769 SkillAtiError::SkillNotFound(_) | SkillAtiError::PathNotFound { .. } => {
1770 StatusCode::NOT_FOUND
1771 }
1772 SkillAtiError::InvalidPath(_) => StatusCode::BAD_REQUEST,
1773 SkillAtiError::Gcs(_)
1774 | SkillAtiError::ProxyRequest(_)
1775 | SkillAtiError::ProxyResponse(_) => StatusCode::BAD_GATEWAY,
1776 };
1777
1778 (
1779 status,
1780 Json(serde_json::json!({
1781 "error": err.to_string(),
1782 })),
1783 )
1784}
1785
1786async fn auth_middleware(
1794 State(state): State<Arc<ProxyState>>,
1795 mut req: HttpRequest<Body>,
1796 next: Next,
1797) -> Result<Response, StatusCode> {
1798 let path = req.uri().path();
1799
1800 if path == "/health" || path == "/.well-known/jwks.json" {
1802 return Ok(next.run(req).await);
1803 }
1804
1805 let jwt_config = match &state.jwt_config {
1807 Some(c) => c,
1808 None => return Ok(next.run(req).await),
1809 };
1810
1811 let token_owned: String = match req
1816 .headers()
1817 .get("authorization")
1818 .and_then(|v| v.to_str().ok())
1819 {
1820 Some(header) if header.starts_with("Bearer ") => header[7..].to_string(),
1821 _ => return Err(StatusCode::UNAUTHORIZED),
1822 };
1823
1824 match jwt::validate(&token_owned, jwt_config) {
1826 Ok(claims) => {
1827 tracing::debug!(sub = %claims.sub, scopes = %claims.scope, "JWT validated");
1828 req.extensions_mut().insert(BearerToken(token_owned));
1834 req.extensions_mut().insert(claims);
1835 Ok(next.run(req).await)
1836 }
1837 Err(e) => {
1838 tracing::debug!(error = %e, "JWT validation failed");
1839 Err(StatusCode::UNAUTHORIZED)
1840 }
1841 }
1842}
1843
/// Raw bearer token taken from a request's `Authorization` header.
///
/// `auth_middleware` stores one of these in the request extensions after the
/// token has been validated. The inner string is a credential, so `Debug` is
/// implemented by hand and never prints it; read `.0` explicitly when the
/// actual token is needed.
#[derive(Clone)]
pub struct BearerToken(pub String);

impl std::fmt::Debug for BearerToken {
    /// Render a fixed placeholder instead of the secret token value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("BearerToken(<redacted>)")
    }
}
1852
1853fn max_call_body_bytes() -> usize {
1863 (crate::core::file_manager::MAX_UPLOAD_BYTES as usize)
1864 .saturating_mul(4)
1865 .saturating_div(3)
1866 .saturating_add(8 * 1024)
1867}
1868
/// Build the proxy's axum [`Router`].
///
/// Registers every HTTP route, applies the request body limit from
/// `max_call_body_bytes()`, attaches `auth_middleware`, and binds the shared
/// `state` to all handlers.
pub fn build_router(state: Arc<ProxyState>) -> Router {
    use axum::extract::DefaultBodyLimit;

    Router::new()
        // Tool invocation, help, and MCP endpoints.
        .route("/call", post(handle_call))
        .route("/help", post(handle_help))
        .route("/mcp", post(handle_mcp))
        // Tool discovery.
        .route("/tools", get(handle_tools_list))
        .route("/tools/{name}", get(handle_tool_info))
        // Local skill registry.
        .route("/skills", get(handle_skills_list))
        .route("/skills/resolve", post(handle_skills_resolve))
        .route("/skills/bundle", post(handle_skills_bundle_batch))
        .route("/skills/{name}", get(handle_skill_detail))
        .route("/skills/{name}/bundle", get(handle_skill_bundle))
        // Remote skills served through the SkillATI client.
        .route("/skillati/catalog", get(handle_skillati_catalog))
        .route("/skillati/{name}", get(handle_skillati_read))
        .route("/skillati/{name}/resources", get(handle_skillati_resources))
        .route("/skillati/{name}/file", get(handle_skillati_file))
        .route("/skillati/{name}/refs", get(handle_skillati_refs))
        .route("/skillati/{name}/ref/{reference}", get(handle_skillati_ref))
        // These two paths are let through by `auth_middleware` without a token.
        .route("/health", get(handle_health))
        .route("/.well-known/jwks.json", get(handle_jwks))
        // Replaces axum's default body limit for the routes above.
        .layer(DefaultBodyLimit::max(max_call_body_bytes()))
        // NOTE(review): this layer is added last, which in axum should make it
        // the outermost one (auth runs before the body limit) — confirm
        // against the axum middleware-ordering docs before reordering.
        .layer(middleware::from_fn_with_state(
            state.clone(),
            auth_middleware,
        ))
        .with_state(state)
}
1902
1903pub async fn run(
1907 port: u16,
1908 bind_addr: Option<String>,
1909 ati_dir: PathBuf,
1910 _verbose: bool,
1911 env_keys: bool,
1912) -> Result<(), Box<dyn std::error::Error>> {
1913 let manifests_dir = ati_dir.join("manifests");
1915 let mut registry = ManifestRegistry::load(&manifests_dir)?;
1916 let provider_count = registry.list_providers().len();
1917
1918 let keyring_source;
1920 let keyring = if env_keys {
1921 let kr = Keyring::from_env();
1923 let key_names = kr.key_names();
1924 tracing::info!(
1925 count = key_names.len(),
1926 "loaded API keys from ATI_KEY_* env vars"
1927 );
1928 for name in &key_names {
1929 tracing::debug!(key = %name, "env key loaded");
1930 }
1931 keyring_source = "env-vars (ATI_KEY_*)";
1932 kr
1933 } else {
1934 let keyring_path = ati_dir.join("keyring.enc");
1936 if keyring_path.exists() {
1937 if let Ok(kr) = Keyring::load(&keyring_path) {
1938 keyring_source = "keyring.enc (sealed key)";
1939 kr
1940 } else if let Ok(kr) = Keyring::load_local(&keyring_path, &ati_dir) {
1941 keyring_source = "keyring.enc (persistent key)";
1942 kr
1943 } else {
1944 tracing::warn!("keyring.enc exists but could not be decrypted");
1945 keyring_source = "empty (decryption failed)";
1946 Keyring::empty()
1947 }
1948 } else {
1949 let creds_path = ati_dir.join("credentials");
1950 if creds_path.exists() {
1951 match Keyring::load_credentials(&creds_path) {
1952 Ok(kr) => {
1953 keyring_source = "credentials (plaintext)";
1954 kr
1955 }
1956 Err(e) => {
1957 tracing::warn!(error = %e, "failed to load credentials");
1958 keyring_source = "empty (credentials error)";
1959 Keyring::empty()
1960 }
1961 }
1962 } else {
1963 tracing::warn!("no keyring.enc or credentials found — running without API keys");
1964 tracing::warn!("tools requiring authentication will fail");
1965 keyring_source = "empty (no auth)";
1966 Keyring::empty()
1967 }
1968 }
1969 };
1970
1971 mcp_client::discover_all_mcp_tools(&mut registry, &keyring).await;
1974
1975 let tool_count = registry.list_public_tools().len();
1976
1977 let mcp_providers: Vec<(String, String)> = registry
1979 .list_mcp_providers()
1980 .iter()
1981 .map(|p| (p.name.clone(), p.mcp_transport_type().to_string()))
1982 .collect();
1983 let mcp_count = mcp_providers.len();
1984 let openapi_providers: Vec<String> = registry
1985 .list_openapi_providers()
1986 .iter()
1987 .map(|p| p.name.clone())
1988 .collect();
1989 let openapi_count = openapi_providers.len();
1990
1991 let skills_dir = ati_dir.join("skills");
1993 let skill_registry = SkillRegistry::load(&skills_dir).unwrap_or_else(|e| {
1994 tracing::warn!(error = %e, "failed to load skills");
1995 SkillRegistry::load(std::path::Path::new("/nonexistent-fallback")).unwrap()
1996 });
1997
1998 if let Ok(registry_url) = std::env::var("ATI_SKILL_REGISTRY") {
1999 if registry_url.strip_prefix("gcs://").is_some() {
2000 tracing::info!(
2001 registry = %registry_url,
2002 "SkillATI remote registry configured for lazy reads"
2003 );
2004 } else {
2005 tracing::warn!(url = %registry_url, "SkillATI only supports gcs:// registries");
2006 }
2007 }
2008
2009 let skill_count = skill_registry.skill_count();
2010
2011 let jwt_config = match jwt::config_from_env() {
2013 Ok(config) => config,
2014 Err(e) => {
2015 tracing::warn!(error = %e, "JWT config error");
2016 None
2017 }
2018 };
2019
2020 let auth_status = if jwt_config.is_some() {
2021 "JWT enabled"
2022 } else {
2023 "DISABLED (no JWT keys configured)"
2024 };
2025
2026 let jwks_json = jwt_config.as_ref().and_then(|config| {
2028 config
2029 .public_key_pem
2030 .as_ref()
2031 .and_then(|pem| jwt::public_key_to_jwks(pem, config.algorithm, "ati-proxy-1").ok())
2032 });
2033
2034 let state = Arc::new(ProxyState {
2035 registry,
2036 skill_registry,
2037 keyring,
2038 jwt_config,
2039 jwks_json,
2040 auth_cache: AuthCache::new(),
2041 });
2042
2043 let app = build_router(state);
2044
2045 let addr: SocketAddr = if let Some(ref bind) = bind_addr {
2046 format!("{bind}:{port}").parse()?
2047 } else {
2048 SocketAddr::from(([127, 0, 0, 1], port))
2049 };
2050
2051 tracing::info!(
2052 version = env!("CARGO_PKG_VERSION"),
2053 %addr,
2054 auth = auth_status,
2055 ati_dir = %ati_dir.display(),
2056 tools = tool_count,
2057 providers = provider_count,
2058 mcp = mcp_count,
2059 openapi = openapi_count,
2060 skills = skill_count,
2061 keyring = keyring_source,
2062 "ATI proxy server starting"
2063 );
2064 for (name, transport) in &mcp_providers {
2065 tracing::info!(provider = %name, transport = %transport, "MCP provider");
2066 }
2067 for name in &openapi_providers {
2068 tracing::info!(provider = %name, "OpenAPI provider");
2069 }
2070
2071 let listener = tokio::net::TcpListener::bind(addr).await?;
2072 axum::serve(listener, app).await?;
2073
2074 Ok(())
2075}
2076
2077async fn dispatch_file_manager(
2080 tool_name: &str,
2081 args: &HashMap<String, Value>,
2082 provider: &Provider,
2083 keyring: &Keyring,
2084) -> Result<Value, (StatusCode, String)> {
2085 use crate::core::file_manager::{self, DownloadArgs, FileManagerError, UploadArgs};
2086
2087 let to_resp = |e: FileManagerError| {
2090 let status =
2091 StatusCode::from_u16(e.http_status()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
2092 (status, e.to_string())
2093 };
2094
2095 match tool_name {
2096 "file_manager:download" => {
2097 let parsed = DownloadArgs::from_value(args).map_err(to_resp)?;
2098 let result = file_manager::fetch_bytes(&parsed).await.map_err(to_resp)?;
2099 Ok(file_manager::build_download_response(&result))
2100 }
2101 "file_manager:upload" => {
2102 let parsed = UploadArgs::from_wire(args).map_err(to_resp)?;
2103 file_manager::upload_to_destination(
2104 parsed,
2105 &provider.upload_destinations,
2106 provider.upload_default_destination.as_deref(),
2107 keyring,
2108 )
2109 .await
2110 .map_err(to_resp)
2111 }
2112 other => Err((
2113 StatusCode::NOT_FOUND,
2114 format!("Unknown file_manager tool: '{other}'"),
2115 )),
2116 }
2117}
2118
2119fn write_proxy_audit(
2120 call_req: &CallRequest,
2121 agent_sub: &str,
2122 claims: Option<&TokenClaims>,
2123 duration: std::time::Duration,
2124 error: Option<&str>,
2125) {
2126 let entry = crate::core::audit::AuditEntry {
2127 ts: chrono::Utc::now().to_rfc3339(),
2128 tool: call_req.tool_name.clone(),
2129 args: crate::core::audit::sanitize_args(&call_req.args),
2130 status: if error.is_some() {
2131 crate::core::audit::AuditStatus::Error
2132 } else {
2133 crate::core::audit::AuditStatus::Ok
2134 },
2135 duration_ms: duration.as_millis() as u64,
2136 agent_sub: agent_sub.to_string(),
2137 job_id: claims.and_then(|c| c.job_id.clone()),
2138 sandbox_id: claims.and_then(|c| c.sandbox_id.clone()),
2139 error: error.map(|s| s.to_string()),
2140 exit_code: None,
2141 };
2142 let _ = crate::core::audit::append(&entry);
2143}
2144
/// System-prompt template for answering an agent's question about which tools
/// to use.
///
/// Contains two placeholders, `{tools}` and `{skills_section}`. This is a plain
/// string, not a `format!` template, so the placeholders are presumably
/// substituted by the code that uses the constant — TODO confirm at the call
/// site (likely the `/help` handler).
const HELP_SYSTEM_PROMPT: &str = r#"You are a helpful assistant for an AI agent that uses external tools via the `ati` CLI.

## Available Tools
{tools}

{skills_section}

Answer the agent's question naturally, like a knowledgeable colleague would. Keep it short but useful:

- Explain which tools to use and why, with `ati run` commands showing realistic parameter values
- If multiple steps are needed, walk through them briefly in order
- Mention important gotchas or parameter choices that matter
- If skills are relevant, tell the agent to load them using the Skill tool (e.g., `skill: "research-financial-data"`)

Keep your answer concise — a few short paragraphs with embedded code blocks. Only recommend tools from the list above."#;
2162
2163async fn build_remote_skillati_section(keyring: &Keyring, query: &str, limit: usize) -> String {
2164 let client = match SkillAtiClient::from_env(keyring) {
2165 Ok(Some(client)) => client,
2166 Ok(None) => return String::new(),
2167 Err(err) => {
2168 tracing::warn!(error = %err, "failed to initialize SkillATI catalog for proxy help");
2169 return String::new();
2170 }
2171 };
2172
2173 let catalog = match client.catalog().await {
2174 Ok(catalog) => catalog,
2175 Err(err) => {
2176 tracing::warn!(error = %err, "failed to load SkillATI catalog for proxy help");
2177 return String::new();
2178 }
2179 };
2180
2181 let matched = SkillAtiClient::filter_catalog(&catalog, query, limit);
2182 if matched.is_empty() {
2183 return String::new();
2184 }
2185
2186 render_remote_skillati_section(&matched, catalog.len())
2187}
2188
2189fn render_remote_skillati_section(skills: &[RemoteSkillMeta], total_catalog: usize) -> String {
2190 let mut section = String::from("## Remote Skills Available Via SkillATI\n\n");
2191 section.push_str(
2192 "These skills are available. Load them using the Skill tool (e.g., `skill: \"skill-name\"`).\n\n",
2193 );
2194
2195 for skill in skills {
2196 section.push_str(&format!("- **{}**: {}\n", skill.name, skill.description));
2197 }
2198
2199 if total_catalog > skills.len() {
2200 section.push_str(&format!(
2201 "\nOnly the most relevant {} remote skills are shown here.\n",
2202 skills.len()
2203 ));
2204 }
2205
2206 section
2207}
2208
/// Join help-prompt sections into one block of text.
///
/// Each section is trimmed; sections that are empty after trimming are
/// dropped, and the rest are separated by a blank line.
fn merge_help_skill_sections(sections: &[String]) -> String {
    let kept: Vec<&str> = sections
        .iter()
        .map(|section| section.trim())
        .filter(|trimmed| !trimmed.is_empty())
        .collect();
    kept.join("\n\n")
}
2223
2224fn build_tool_context(
2225 tools: &[(
2226 &crate::core::manifest::Provider,
2227 &crate::core::manifest::Tool,
2228 )],
2229) -> String {
2230 let mut summaries = Vec::new();
2231 for (provider, tool) in tools {
2232 let mut summary = if let Some(cat) = &provider.category {
2233 format!(
2234 "- **{}** (provider: {}, category: {}): {}",
2235 tool.name, provider.name, cat, tool.description
2236 )
2237 } else {
2238 format!(
2239 "- **{}** (provider: {}): {}",
2240 tool.name, provider.name, tool.description
2241 )
2242 };
2243 if !tool.tags.is_empty() {
2244 summary.push_str(&format!("\n Tags: {}", tool.tags.join(", ")));
2245 }
2246 if provider.is_cli() && tool.input_schema.is_none() {
2248 let cmd = provider.cli_command.as_deref().unwrap_or("?");
2249 summary.push_str(&format!(
2250 "\n Usage: `ati run {} -- <args>` (passthrough to `{}`)",
2251 tool.name, cmd
2252 ));
2253 } else if let Some(schema) = &tool.input_schema {
2254 if let Some(props) = schema.get("properties") {
2255 if let Some(obj) = props.as_object() {
2256 let params: Vec<String> = obj
2257 .iter()
2258 .filter(|(_, v)| {
2259 v.get("x-ati-param-location").is_none()
2260 || v.get("description").is_some()
2261 })
2262 .map(|(k, v)| {
2263 let type_str =
2264 v.get("type").and_then(|t| t.as_str()).unwrap_or("string");
2265 let desc = v.get("description").and_then(|d| d.as_str()).unwrap_or("");
2266 format!(" --{k} ({type_str}): {desc}")
2267 })
2268 .collect();
2269 if !params.is_empty() {
2270 summary.push_str("\n Parameters:\n");
2271 summary.push_str(¶ms.join("\n"));
2272 }
2273 }
2274 }
2275 }
2276 summaries.push(summary);
2277 }
2278 summaries.join("\n\n")
2279}
2280
2281fn build_scoped_prompt(
2285 scope_name: &str,
2286 visible_tools: &[(&Provider, &Tool)],
2287 skills_section: &str,
2288) -> Option<String> {
2289 if let Some((provider, tool)) = visible_tools
2291 .iter()
2292 .find(|(_, tool)| tool.name == scope_name)
2293 {
2294 let mut details = format!(
2295 "**Name**: `{}`\n**Provider**: {} (handler: {})\n**Description**: {}\n",
2296 tool.name, provider.name, provider.handler, tool.description
2297 );
2298 if let Some(cat) = &provider.category {
2299 details.push_str(&format!("**Category**: {}\n", cat));
2300 }
2301 if provider.is_cli() {
2302 let cmd = provider.cli_command.as_deref().unwrap_or("?");
2303 details.push_str(&format!(
2304 "\n**Usage**: `ati run {} -- <args>` (passthrough to `{}`)\n",
2305 tool.name, cmd
2306 ));
2307 } else if let Some(schema) = &tool.input_schema {
2308 if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
2309 let required: Vec<String> = schema
2310 .get("required")
2311 .and_then(|r| r.as_array())
2312 .map(|arr| {
2313 arr.iter()
2314 .filter_map(|v| v.as_str().map(|s| s.to_string()))
2315 .collect()
2316 })
2317 .unwrap_or_default();
2318 details.push_str("\n**Parameters**:\n");
2319 for (key, val) in props {
2320 let type_str = val.get("type").and_then(|t| t.as_str()).unwrap_or("string");
2321 let desc = val
2322 .get("description")
2323 .and_then(|d| d.as_str())
2324 .unwrap_or("");
2325 let req = if required.contains(key) {
2326 " **(required)**"
2327 } else {
2328 ""
2329 };
2330 details.push_str(&format!("- `--{key}` ({type_str}{req}): {desc}\n"));
2331 }
2332 }
2333 }
2334
2335 let prompt = format!(
2336 "You are an expert assistant for the `{}` tool, accessed via the `ati` CLI.\n\n\
2337 ## Tool Details\n{}\n\n{}\n\n\
2338 Answer the agent's question about this specific tool. Provide exact commands, explain flags and options, and give practical examples. Be concise and actionable.",
2339 tool.name, details, skills_section
2340 );
2341 return Some(prompt);
2342 }
2343
2344 let tools: Vec<(&Provider, &Tool)> = visible_tools
2346 .iter()
2347 .copied()
2348 .filter(|(provider, _)| provider.name == scope_name)
2349 .collect();
2350 if !tools.is_empty() {
2351 let tools_context = build_tool_context(&tools);
2352 let prompt = format!(
2353 "You are an expert assistant for the `{}` provider's tools, accessed via the `ati` CLI.\n\n\
2354 ## Tools in provider `{}`\n{}\n\n{}\n\n\
2355 Answer the agent's question about these tools. Provide exact `ati run` commands, explain parameters, and give practical examples. Be concise and actionable.",
2356 scope_name, scope_name, tools_context, skills_section
2357 );
2358 return Some(prompt);
2359 }
2360
2361 None
2362}