1use axum::{
8 body::Body,
9 extract::State,
10 http::{Request as HttpRequest, StatusCode},
11 middleware::{self, Next},
12 response::{IntoResponse, Response},
13 routing::{get, post},
14 Json, Router,
15};
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use std::collections::HashMap;
19use std::net::SocketAddr;
20use std::path::PathBuf;
21use std::sync::Arc;
22
23use crate::core::auth_generator::{AuthCache, GenContext};
24use crate::core::http;
25use crate::core::jwt::{self, JwtConfig, TokenClaims};
26use crate::core::keyring::Keyring;
27use crate::core::manifest::ManifestRegistry;
28use crate::core::mcp_client;
29use crate::core::response;
30use crate::core::scope::ScopeConfig;
31use crate::core::skill::{self, SkillRegistry};
32use crate::core::xai;
33
34pub struct ProxyState {
36 pub registry: ManifestRegistry,
37 pub skill_registry: SkillRegistry,
38 pub keyring: Keyring,
39 pub jwt_config: Option<JwtConfig>,
41 pub jwks_json: Option<Value>,
43 pub auth_cache: AuthCache,
45}
46
47#[derive(Debug, Deserialize)]
50pub struct CallRequest {
51 pub tool_name: String,
52 #[serde(default = "default_args")]
56 pub args: Value,
57 #[serde(default)]
60 pub raw_args: Option<Vec<String>>,
61}
62
63fn default_args() -> Value {
64 Value::Object(serde_json::Map::new())
65}
66
67impl CallRequest {
68 fn args_as_map(&self) -> HashMap<String, Value> {
72 match &self.args {
73 Value::Object(map) => map.iter().map(|(k, v)| (k.clone(), v.clone())).collect(),
74 _ => HashMap::new(),
75 }
76 }
77
78 fn args_as_positional(&self) -> Vec<String> {
81 if let Some(ref raw) = self.raw_args {
83 return raw.clone();
84 }
85 match &self.args {
86 Value::Array(arr) => arr
88 .iter()
89 .map(|v| match v {
90 Value::String(s) => s.clone(),
91 other => other.to_string(),
92 })
93 .collect(),
94 Value::String(s) => s.split_whitespace().map(String::from).collect(),
96 Value::Object(map) => {
98 if let Some(Value::Array(pos)) = map.get("_positional") {
99 return pos
100 .iter()
101 .map(|v| match v {
102 Value::String(s) => s.clone(),
103 other => other.to_string(),
104 })
105 .collect();
106 }
107 let mut result = Vec::new();
109 for (k, v) in map {
110 result.push(format!("--{k}"));
111 match v {
112 Value::String(s) => result.push(s.clone()),
113 Value::Bool(true) => {} other => result.push(other.to_string()),
115 }
116 }
117 result
118 }
119 _ => Vec::new(),
120 }
121 }
122}
123
124#[derive(Debug, Serialize)]
125pub struct CallResponse {
126 pub result: Value,
127 #[serde(skip_serializing_if = "Option::is_none")]
128 pub error: Option<String>,
129}
130
131#[derive(Debug, Deserialize)]
132pub struct HelpRequest {
133 pub query: String,
134 #[serde(default)]
135 pub tool: Option<String>,
136}
137
138#[derive(Debug, Serialize)]
139pub struct HelpResponse {
140 pub content: String,
141 #[serde(skip_serializing_if = "Option::is_none")]
142 pub error: Option<String>,
143}
144
145#[derive(Debug, Serialize)]
146pub struct HealthResponse {
147 pub status: String,
148 pub version: String,
149 pub tools: usize,
150 pub providers: usize,
151 pub skills: usize,
152 pub auth: String,
153}
154
155#[derive(Debug, Deserialize)]
158pub struct SkillsQuery {
159 #[serde(default)]
160 pub category: Option<String>,
161 #[serde(default)]
162 pub provider: Option<String>,
163 #[serde(default)]
164 pub tool: Option<String>,
165 #[serde(default)]
166 pub search: Option<String>,
167}
168
169#[derive(Debug, Deserialize)]
170pub struct SkillDetailQuery {
171 #[serde(default)]
172 pub meta: Option<bool>,
173 #[serde(default)]
174 pub refs: Option<bool>,
175}
176
177#[derive(Debug, Deserialize)]
178pub struct SkillResolveRequest {
179 pub scopes: Vec<String>,
180 #[serde(default)]
182 pub include_content: bool,
183}
184
185#[derive(Debug, Deserialize)]
186pub struct SkillBundleBatchRequest {
187 pub names: Vec<String>,
188}
189
190#[derive(Debug, Deserialize)]
193pub struct ToolsQuery {
194 #[serde(default)]
195 pub provider: Option<String>,
196 #[serde(default)]
197 pub search: Option<String>,
198}
199
200async fn handle_call(
203 State(state): State<Arc<ProxyState>>,
204 req: HttpRequest<Body>,
205) -> impl IntoResponse {
206 let claims = req.extensions().get::<TokenClaims>().cloned();
208
209 let body_bytes = match axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024).await {
211 Ok(b) => b,
212 Err(e) => {
213 return (
214 StatusCode::BAD_REQUEST,
215 Json(CallResponse {
216 result: Value::Null,
217 error: Some(format!("Failed to read request body: {e}")),
218 }),
219 );
220 }
221 };
222
223 let call_req: CallRequest = match serde_json::from_slice(&body_bytes) {
224 Ok(r) => r,
225 Err(e) => {
226 return (
227 StatusCode::UNPROCESSABLE_ENTITY,
228 Json(CallResponse {
229 result: Value::Null,
230 error: Some(format!("Invalid request: {e}")),
231 }),
232 );
233 }
234 };
235
236 tracing::debug!(
237 tool = %call_req.tool_name,
238 args = ?call_req.args,
239 "POST /call"
240 );
241
242 let (provider, tool) = match state.registry.get_tool(&call_req.tool_name) {
245 Some(pt) => pt,
246 None => {
247 let mut resolved = None;
251 for (idx, _) in call_req.tool_name.match_indices('_') {
252 let candidate = format!(
253 "{}:{}",
254 &call_req.tool_name[..idx],
255 &call_req.tool_name[idx + 1..]
256 );
257 if let Some(pt) = state.registry.get_tool(&candidate) {
258 tracing::debug!(
259 original = %call_req.tool_name,
260 resolved = %candidate,
261 "resolved underscore tool name to colon format"
262 );
263 resolved = Some(pt);
264 break;
265 }
266 }
267
268 match resolved {
269 Some(pt) => pt,
270 None => {
271 return (
272 StatusCode::NOT_FOUND,
273 Json(CallResponse {
274 result: Value::Null,
275 error: Some(format!("Unknown tool: '{}'", call_req.tool_name)),
276 }),
277 );
278 }
279 }
280 }
281 };
282
283 if let Some(tool_scope) = &tool.scope {
285 let scopes = match &claims {
286 Some(c) => ScopeConfig::from_jwt(c),
287 None if state.jwt_config.is_none() => ScopeConfig::unrestricted(), None => {
289 return (
290 StatusCode::FORBIDDEN,
291 Json(CallResponse {
292 result: Value::Null,
293 error: Some("Authentication required — no JWT provided".into()),
294 }),
295 );
296 }
297 };
298
299 let underscore_scope = if let Some(after_prefix) = tool_scope.strip_prefix("tool:") {
302 format!("tool:{}", after_prefix.replacen(':', "_", 1))
304 } else {
305 String::new()
306 };
307
308 let allowed = scopes.is_allowed(tool_scope)
309 || (!underscore_scope.is_empty() && scopes.is_allowed(&underscore_scope));
310
311 if !allowed {
312 return (
313 StatusCode::FORBIDDEN,
314 Json(CallResponse {
315 result: Value::Null,
316 error: Some(format!(
317 "Access denied: '{}' is not in your scopes",
318 tool.name
319 )),
320 }),
321 );
322 }
323 }
324
325 {
327 let scopes = match &claims {
328 Some(c) => ScopeConfig::from_jwt(c),
329 None => ScopeConfig::unrestricted(),
330 };
331 if let Some(ref rate_config) = scopes.rate_config {
332 if let Err(e) = crate::core::rate::check_and_record(&call_req.tool_name, rate_config) {
333 return (
334 StatusCode::TOO_MANY_REQUESTS,
335 Json(CallResponse {
336 result: Value::Null,
337 error: Some(format!("{e}")),
338 }),
339 );
340 }
341 }
342 }
343
344 let gen_ctx = GenContext {
346 jwt_sub: claims
347 .as_ref()
348 .map(|c| c.sub.clone())
349 .unwrap_or_else(|| "dev".into()),
350 jwt_scope: claims
351 .as_ref()
352 .map(|c| c.scope.clone())
353 .unwrap_or_else(|| "*".into()),
354 tool_name: call_req.tool_name.clone(),
355 timestamp: crate::core::jwt::now_secs(),
356 };
357
358 let agent_sub = claims.as_ref().map(|c| c.sub.clone()).unwrap_or_default();
360 let start = std::time::Instant::now();
361
362 let response = match provider.handler.as_str() {
363 "mcp" => {
364 let args_map = call_req.args_as_map();
365 match mcp_client::execute_with_gen(
366 provider,
367 &call_req.tool_name,
368 &args_map,
369 &state.keyring,
370 Some(&gen_ctx),
371 Some(&state.auth_cache),
372 )
373 .await
374 {
375 Ok(result) => (
376 StatusCode::OK,
377 Json(CallResponse {
378 result,
379 error: None,
380 }),
381 ),
382 Err(e) => (
383 StatusCode::BAD_GATEWAY,
384 Json(CallResponse {
385 result: Value::Null,
386 error: Some(format!("MCP error: {e}")),
387 }),
388 ),
389 }
390 }
391 "cli" => {
392 let positional = call_req.args_as_positional();
393 match crate::core::cli_executor::execute_with_gen(
394 provider,
395 &positional,
396 &state.keyring,
397 Some(&gen_ctx),
398 Some(&state.auth_cache),
399 )
400 .await
401 {
402 Ok(result) => (
403 StatusCode::OK,
404 Json(CallResponse {
405 result,
406 error: None,
407 }),
408 ),
409 Err(e) => (
410 StatusCode::BAD_GATEWAY,
411 Json(CallResponse {
412 result: Value::Null,
413 error: Some(format!("CLI error: {e}")),
414 }),
415 ),
416 }
417 }
418 _ => {
419 let args_map = call_req.args_as_map();
420 let raw_response = match match provider.handler.as_str() {
421 "xai" => xai::execute_xai_tool(provider, tool, &args_map, &state.keyring).await,
422 _ => {
423 http::execute_tool_with_gen(
424 provider,
425 tool,
426 &args_map,
427 &state.keyring,
428 Some(&gen_ctx),
429 Some(&state.auth_cache),
430 )
431 .await
432 }
433 } {
434 Ok(resp) => resp,
435 Err(e) => {
436 let duration = start.elapsed();
437 write_proxy_audit(&call_req, &agent_sub, duration, Some(&e.to_string()));
438 return (
439 StatusCode::BAD_GATEWAY,
440 Json(CallResponse {
441 result: Value::Null,
442 error: Some(format!("Upstream API error: {e}")),
443 }),
444 );
445 }
446 };
447
448 let processed = match response::process_response(&raw_response, tool.response.as_ref())
449 {
450 Ok(p) => p,
451 Err(e) => {
452 let duration = start.elapsed();
453 write_proxy_audit(&call_req, &agent_sub, duration, Some(&e.to_string()));
454 return (
455 StatusCode::INTERNAL_SERVER_ERROR,
456 Json(CallResponse {
457 result: raw_response,
458 error: Some(format!("Response processing error: {e}")),
459 }),
460 );
461 }
462 };
463
464 (
465 StatusCode::OK,
466 Json(CallResponse {
467 result: processed,
468 error: None,
469 }),
470 )
471 }
472 };
473
474 let duration = start.elapsed();
475 let error_msg = response.1.error.as_deref();
476 write_proxy_audit(&call_req, &agent_sub, duration, error_msg);
477
478 response
479}
480
481async fn handle_help(
482 State(state): State<Arc<ProxyState>>,
483 Json(req): Json<HelpRequest>,
484) -> impl IntoResponse {
485 tracing::debug!(query = %req.query, tool = ?req.tool, "POST /help");
486
487 let (llm_provider, llm_tool) = match state.registry.get_tool("_chat_completion") {
488 Some(pt) => pt,
489 None => {
490 return (
491 StatusCode::SERVICE_UNAVAILABLE,
492 Json(HelpResponse {
493 content: String::new(),
494 error: Some("No _llm.toml manifest found. Proxy help requires a configured LLM provider.".into()),
495 }),
496 );
497 }
498 };
499
500 let api_key = match llm_provider
501 .auth_key_name
502 .as_deref()
503 .and_then(|k| state.keyring.get(k))
504 {
505 Some(key) => key.to_string(),
506 None => {
507 return (
508 StatusCode::SERVICE_UNAVAILABLE,
509 Json(HelpResponse {
510 content: String::new(),
511 error: Some("LLM API key not found in keyring".into()),
512 }),
513 );
514 }
515 };
516
517 let scopes = ScopeConfig::unrestricted();
518 let resolved_skills = skill::resolve_skills(&state.skill_registry, &state.registry, &scopes);
519 let skills_section = if resolved_skills.is_empty() {
520 String::new()
521 } else {
522 format!(
523 "## Available Skills (methodology guides)\n{}",
524 skill::build_skill_context(&resolved_skills)
525 )
526 };
527
528 let system_prompt = if let Some(ref tool_name) = req.tool {
530 match build_scoped_prompt(tool_name, &state.registry, &skills_section) {
532 Some(prompt) => prompt,
533 None => {
534 tracing::debug!(scope = %tool_name, "scope not found, falling back to unscoped");
536 let all_tools = state.registry.list_public_tools();
537 let tools_context = build_tool_context(&all_tools);
538 HELP_SYSTEM_PROMPT
539 .replace("{tools}", &tools_context)
540 .replace("{skills_section}", &skills_section)
541 }
542 }
543 } else {
544 let all_tools = state.registry.list_public_tools();
545 let tools_context = build_tool_context(&all_tools);
546 HELP_SYSTEM_PROMPT
547 .replace("{tools}", &tools_context)
548 .replace("{skills_section}", &skills_section)
549 };
550
551 let request_body = serde_json::json!({
552 "model": "zai-glm-4.7",
553 "messages": [
554 {"role": "system", "content": system_prompt},
555 {"role": "user", "content": req.query}
556 ],
557 "max_completion_tokens": 1536,
558 "temperature": 0.3
559 });
560
561 let client = reqwest::Client::new();
562 let url = format!(
563 "{}{}",
564 llm_provider.base_url.trim_end_matches('/'),
565 llm_tool.endpoint
566 );
567
568 let response = match client
569 .post(&url)
570 .bearer_auth(&api_key)
571 .json(&request_body)
572 .send()
573 .await
574 {
575 Ok(r) => r,
576 Err(e) => {
577 return (
578 StatusCode::BAD_GATEWAY,
579 Json(HelpResponse {
580 content: String::new(),
581 error: Some(format!("LLM request failed: {e}")),
582 }),
583 );
584 }
585 };
586
587 if !response.status().is_success() {
588 let status = response.status();
589 let body = response.text().await.unwrap_or_default();
590 return (
591 StatusCode::BAD_GATEWAY,
592 Json(HelpResponse {
593 content: String::new(),
594 error: Some(format!("LLM API error ({status}): {body}")),
595 }),
596 );
597 }
598
599 let body: Value = match response.json().await {
600 Ok(b) => b,
601 Err(e) => {
602 return (
603 StatusCode::INTERNAL_SERVER_ERROR,
604 Json(HelpResponse {
605 content: String::new(),
606 error: Some(format!("Failed to parse LLM response: {e}")),
607 }),
608 );
609 }
610 };
611
612 let content = body
613 .pointer("/choices/0/message/content")
614 .and_then(|c| c.as_str())
615 .unwrap_or("No response from LLM")
616 .to_string();
617
618 (
619 StatusCode::OK,
620 Json(HelpResponse {
621 content,
622 error: None,
623 }),
624 )
625}
626
627async fn handle_health(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
628 let auth = if state.jwt_config.is_some() {
629 "jwt"
630 } else {
631 "disabled"
632 };
633
634 Json(HealthResponse {
635 status: "ok".into(),
636 version: env!("CARGO_PKG_VERSION").into(),
637 tools: state.registry.list_public_tools().len(),
638 providers: state.registry.list_providers().len(),
639 skills: state.skill_registry.skill_count(),
640 auth: auth.into(),
641 })
642}
643
644async fn handle_jwks(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
646 match &state.jwks_json {
647 Some(jwks) => (StatusCode::OK, Json(jwks.clone())),
648 None => (
649 StatusCode::NOT_FOUND,
650 Json(serde_json::json!({"error": "JWKS not configured"})),
651 ),
652 }
653}
654
655async fn handle_mcp(
660 State(state): State<Arc<ProxyState>>,
661 Json(msg): Json<Value>,
662) -> impl IntoResponse {
663 let method = msg.get("method").and_then(|m| m.as_str()).unwrap_or("");
664 let id = msg.get("id").cloned();
665
666 tracing::debug!(%method, "POST /mcp");
667
668 match method {
669 "initialize" => {
670 let result = serde_json::json!({
671 "protocolVersion": "2025-03-26",
672 "capabilities": {
673 "tools": { "listChanged": false }
674 },
675 "serverInfo": {
676 "name": "ati-proxy",
677 "version": env!("CARGO_PKG_VERSION")
678 }
679 });
680 jsonrpc_success(id, result)
681 }
682
683 "notifications/initialized" => (StatusCode::ACCEPTED, Json(Value::Null)),
684
685 "tools/list" => {
686 let all_tools = state.registry.list_public_tools();
687 let mcp_tools: Vec<Value> = all_tools
688 .iter()
689 .map(|(_provider, tool)| {
690 serde_json::json!({
691 "name": tool.name,
692 "description": tool.description,
693 "inputSchema": tool.input_schema.clone().unwrap_or(serde_json::json!({
694 "type": "object",
695 "properties": {}
696 }))
697 })
698 })
699 .collect();
700
701 let result = serde_json::json!({
702 "tools": mcp_tools,
703 });
704 jsonrpc_success(id, result)
705 }
706
707 "tools/call" => {
708 let params = msg.get("params").cloned().unwrap_or(Value::Null);
709 let tool_name = params.get("name").and_then(|n| n.as_str()).unwrap_or("");
710 let arguments: HashMap<String, Value> = params
711 .get("arguments")
712 .and_then(|a| serde_json::from_value(a.clone()).ok())
713 .unwrap_or_default();
714
715 if tool_name.is_empty() {
716 return jsonrpc_error(id, -32602, "Missing tool name in params.name");
717 }
718
719 let (provider, _tool) = match state.registry.get_tool(tool_name) {
720 Some(pt) => pt,
721 None => {
722 return jsonrpc_error(id, -32602, &format!("Unknown tool: '{tool_name}'"));
723 }
724 };
725
726 tracing::debug!(%tool_name, provider = %provider.name, "MCP tools/call");
727
728 let mcp_gen_ctx = GenContext {
729 jwt_sub: "dev".into(),
730 jwt_scope: "*".into(),
731 tool_name: tool_name.to_string(),
732 timestamp: crate::core::jwt::now_secs(),
733 };
734
735 let result = if provider.is_mcp() {
736 mcp_client::execute_with_gen(
737 provider,
738 tool_name,
739 &arguments,
740 &state.keyring,
741 Some(&mcp_gen_ctx),
742 Some(&state.auth_cache),
743 )
744 .await
745 } else if provider.is_cli() {
746 let raw: Vec<String> = arguments
748 .iter()
749 .flat_map(|(k, v)| {
750 let val = match v {
751 Value::String(s) => s.clone(),
752 other => other.to_string(),
753 };
754 vec![format!("--{k}"), val]
755 })
756 .collect();
757 crate::core::cli_executor::execute_with_gen(
758 provider,
759 &raw,
760 &state.keyring,
761 Some(&mcp_gen_ctx),
762 Some(&state.auth_cache),
763 )
764 .await
765 .map_err(|e| mcp_client::McpError::Transport(e.to_string()))
766 } else {
767 match match provider.handler.as_str() {
768 "xai" => {
769 xai::execute_xai_tool(provider, _tool, &arguments, &state.keyring).await
770 }
771 _ => {
772 http::execute_tool_with_gen(
773 provider,
774 _tool,
775 &arguments,
776 &state.keyring,
777 Some(&mcp_gen_ctx),
778 Some(&state.auth_cache),
779 )
780 .await
781 }
782 } {
783 Ok(val) => Ok(val),
784 Err(e) => Err(mcp_client::McpError::Transport(e.to_string())),
785 }
786 };
787
788 match result {
789 Ok(value) => {
790 let text = match &value {
791 Value::String(s) => s.clone(),
792 other => serde_json::to_string_pretty(other).unwrap_or_default(),
793 };
794 let mcp_result = serde_json::json!({
795 "content": [{"type": "text", "text": text}],
796 "isError": false,
797 });
798 jsonrpc_success(id, mcp_result)
799 }
800 Err(e) => {
801 let mcp_result = serde_json::json!({
802 "content": [{"type": "text", "text": format!("Error: {e}")}],
803 "isError": true,
804 });
805 jsonrpc_success(id, mcp_result)
806 }
807 }
808 }
809
810 _ => jsonrpc_error(id, -32601, &format!("Method not found: '{method}'")),
811 }
812}
813
814fn jsonrpc_success(id: Option<Value>, result: Value) -> (StatusCode, Json<Value>) {
815 (
816 StatusCode::OK,
817 Json(serde_json::json!({
818 "jsonrpc": "2.0",
819 "id": id,
820 "result": result,
821 })),
822 )
823}
824
825fn jsonrpc_error(id: Option<Value>, code: i64, message: &str) -> (StatusCode, Json<Value>) {
826 (
827 StatusCode::OK,
828 Json(serde_json::json!({
829 "jsonrpc": "2.0",
830 "id": id,
831 "error": {
832 "code": code,
833 "message": message,
834 }
835 })),
836 )
837}
838
839async fn handle_tools_list(
845 State(state): State<Arc<ProxyState>>,
846 axum::extract::Query(query): axum::extract::Query<ToolsQuery>,
847) -> impl IntoResponse {
848 tracing::debug!(
849 provider = ?query.provider,
850 search = ?query.search,
851 "GET /tools"
852 );
853
854 let all_tools = state.registry.list_public_tools();
855
856 let tools: Vec<Value> = all_tools
857 .iter()
858 .filter(|(provider, tool)| {
859 if let Some(ref p) = query.provider {
860 if provider.name != *p {
861 return false;
862 }
863 }
864 if let Some(ref q) = query.search {
865 let q = q.to_lowercase();
866 let name_match = tool.name.to_lowercase().contains(&q);
867 let desc_match = tool.description.to_lowercase().contains(&q);
868 let tag_match = tool.tags.iter().any(|t| t.to_lowercase().contains(&q));
869 if !name_match && !desc_match && !tag_match {
870 return false;
871 }
872 }
873 true
874 })
875 .map(|(provider, tool)| {
876 serde_json::json!({
877 "name": tool.name,
878 "description": tool.description,
879 "provider": provider.name,
880 "method": format!("{:?}", tool.method),
881 "tags": tool.tags,
882 "input_schema": tool.input_schema,
883 })
884 })
885 .collect();
886
887 (StatusCode::OK, Json(Value::Array(tools)))
888}
889
890async fn handle_tool_info(
892 State(state): State<Arc<ProxyState>>,
893 axum::extract::Path(name): axum::extract::Path<String>,
894) -> impl IntoResponse {
895 tracing::debug!(tool = %name, "GET /tools/:name");
896
897 match state.registry.get_tool(&name) {
898 Some((provider, tool)) => (
899 StatusCode::OK,
900 Json(serde_json::json!({
901 "name": tool.name,
902 "description": tool.description,
903 "provider": provider.name,
904 "method": format!("{:?}", tool.method),
905 "endpoint": tool.endpoint,
906 "tags": tool.tags,
907 "hint": tool.hint,
908 "input_schema": tool.input_schema,
909 "scope": tool.scope,
910 })),
911 ),
912 None => (
913 StatusCode::NOT_FOUND,
914 Json(serde_json::json!({"error": format!("Tool '{name}' not found")})),
915 ),
916 }
917}
918
919async fn handle_skills_list(
924 State(state): State<Arc<ProxyState>>,
925 axum::extract::Query(query): axum::extract::Query<SkillsQuery>,
926) -> impl IntoResponse {
927 tracing::debug!(
928 category = ?query.category,
929 provider = ?query.provider,
930 tool = ?query.tool,
931 search = ?query.search,
932 "GET /skills"
933 );
934
935 let skills: Vec<&skill::SkillMeta> = if let Some(search_query) = &query.search {
936 state.skill_registry.search(search_query)
937 } else if let Some(cat) = &query.category {
938 state.skill_registry.skills_for_category(cat)
939 } else if let Some(prov) = &query.provider {
940 state.skill_registry.skills_for_provider(prov)
941 } else if let Some(t) = &query.tool {
942 state.skill_registry.skills_for_tool(t)
943 } else {
944 state.skill_registry.list_skills().iter().collect()
945 };
946
947 let json: Vec<Value> = skills
948 .iter()
949 .map(|s| {
950 serde_json::json!({
951 "name": s.name,
952 "version": s.version,
953 "description": s.description,
954 "tools": s.tools,
955 "providers": s.providers,
956 "categories": s.categories,
957 "hint": s.hint,
958 })
959 })
960 .collect();
961
962 (StatusCode::OK, Json(Value::Array(json)))
963}
964
965async fn handle_skill_detail(
966 State(state): State<Arc<ProxyState>>,
967 axum::extract::Path(name): axum::extract::Path<String>,
968 axum::extract::Query(query): axum::extract::Query<SkillDetailQuery>,
969) -> impl IntoResponse {
970 tracing::debug!(%name, meta = ?query.meta, refs = ?query.refs, "GET /skills/:name");
971
972 let skill_meta = match state.skill_registry.get_skill(&name) {
973 Some(s) => s,
974 None => {
975 return (
976 StatusCode::NOT_FOUND,
977 Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
978 );
979 }
980 };
981
982 if query.meta.unwrap_or(false) {
983 return (
984 StatusCode::OK,
985 Json(serde_json::json!({
986 "name": skill_meta.name,
987 "version": skill_meta.version,
988 "description": skill_meta.description,
989 "author": skill_meta.author,
990 "tools": skill_meta.tools,
991 "providers": skill_meta.providers,
992 "categories": skill_meta.categories,
993 "keywords": skill_meta.keywords,
994 "hint": skill_meta.hint,
995 "depends_on": skill_meta.depends_on,
996 "suggests": skill_meta.suggests,
997 "license": skill_meta.license,
998 "compatibility": skill_meta.compatibility,
999 "allowed_tools": skill_meta.allowed_tools,
1000 "format": skill_meta.format,
1001 })),
1002 );
1003 }
1004
1005 let content = match state.skill_registry.read_content(&name) {
1006 Ok(c) => c,
1007 Err(e) => {
1008 return (
1009 StatusCode::INTERNAL_SERVER_ERROR,
1010 Json(serde_json::json!({"error": format!("Failed to read skill: {e}")})),
1011 );
1012 }
1013 };
1014
1015 let mut response = serde_json::json!({
1016 "name": skill_meta.name,
1017 "version": skill_meta.version,
1018 "description": skill_meta.description,
1019 "content": content,
1020 });
1021
1022 if query.refs.unwrap_or(false) {
1023 if let Ok(refs) = state.skill_registry.list_references(&name) {
1024 response["references"] = serde_json::json!(refs);
1025 }
1026 }
1027
1028 (StatusCode::OK, Json(response))
1029}
1030
1031async fn handle_skill_bundle(
1035 State(state): State<Arc<ProxyState>>,
1036 axum::extract::Path(name): axum::extract::Path<String>,
1037) -> impl IntoResponse {
1038 tracing::debug!(skill = %name, "GET /skills/:name/bundle");
1039
1040 let files = match state.skill_registry.bundle_files(&name) {
1041 Ok(f) => f,
1042 Err(_) => {
1043 return (
1044 StatusCode::NOT_FOUND,
1045 Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
1046 );
1047 }
1048 };
1049
1050 let mut file_map = serde_json::Map::new();
1052 for (path, data) in &files {
1053 match std::str::from_utf8(data) {
1054 Ok(text) => {
1055 file_map.insert(path.clone(), Value::String(text.to_string()));
1056 }
1057 Err(_) => {
1058 use base64::Engine;
1060 let encoded = base64::engine::general_purpose::STANDARD.encode(data);
1061 file_map.insert(path.clone(), serde_json::json!({"base64": encoded}));
1062 }
1063 }
1064 }
1065
1066 (
1067 StatusCode::OK,
1068 Json(serde_json::json!({
1069 "name": name,
1070 "files": file_map,
1071 })),
1072 )
1073}
1074
1075async fn handle_skills_bundle_batch(
1079 State(state): State<Arc<ProxyState>>,
1080 Json(req): Json<SkillBundleBatchRequest>,
1081) -> impl IntoResponse {
1082 const MAX_BATCH: usize = 50;
1083 if req.names.len() > MAX_BATCH {
1084 return (
1085 StatusCode::BAD_REQUEST,
1086 Json(
1087 serde_json::json!({"error": format!("batch size {} exceeds limit of {MAX_BATCH}", req.names.len())}),
1088 ),
1089 );
1090 }
1091
1092 tracing::debug!(names = ?req.names, "POST /skills/bundle");
1093
1094 let mut result = serde_json::Map::new();
1095 let mut missing: Vec<String> = Vec::new();
1096
1097 for name in &req.names {
1098 let files = match state.skill_registry.bundle_files(name) {
1099 Ok(f) => f,
1100 Err(_) => {
1101 missing.push(name.clone());
1102 continue;
1103 }
1104 };
1105
1106 let mut file_map = serde_json::Map::new();
1107 for (path, data) in &files {
1108 match std::str::from_utf8(data) {
1109 Ok(text) => {
1110 file_map.insert(path.clone(), Value::String(text.to_string()));
1111 }
1112 Err(_) => {
1113 use base64::Engine;
1114 let encoded = base64::engine::general_purpose::STANDARD.encode(data);
1115 file_map.insert(path.clone(), serde_json::json!({"base64": encoded}));
1116 }
1117 }
1118 }
1119
1120 result.insert(name.clone(), serde_json::json!({ "files": file_map }));
1121 }
1122
1123 (
1124 StatusCode::OK,
1125 Json(serde_json::json!({ "skills": result, "missing": missing })),
1126 )
1127}
1128
1129async fn handle_skills_resolve(
1130 State(state): State<Arc<ProxyState>>,
1131 Json(req): Json<SkillResolveRequest>,
1132) -> impl IntoResponse {
1133 tracing::debug!(scopes = ?req.scopes, include_content = req.include_content, "POST /skills/resolve");
1134
1135 let include_content = req.include_content;
1136 let scopes = ScopeConfig {
1137 scopes: req.scopes,
1138 sub: String::new(),
1139 expires_at: 0,
1140 rate_config: None,
1141 };
1142
1143 let resolved = skill::resolve_skills(&state.skill_registry, &state.registry, &scopes);
1144
1145 let json: Vec<Value> = resolved
1146 .iter()
1147 .map(|s| {
1148 let mut entry = serde_json::json!({
1149 "name": s.name,
1150 "version": s.version,
1151 "description": s.description,
1152 "tools": s.tools,
1153 "providers": s.providers,
1154 "categories": s.categories,
1155 });
1156 if include_content {
1157 if let Ok(content) = state.skill_registry.read_content(&s.name) {
1158 entry["content"] = Value::String(content);
1159 }
1160 }
1161 entry
1162 })
1163 .collect();
1164
1165 (StatusCode::OK, Json(Value::Array(json)))
1166}
1167
1168async fn auth_middleware(
1176 State(state): State<Arc<ProxyState>>,
1177 mut req: HttpRequest<Body>,
1178 next: Next,
1179) -> Result<Response, StatusCode> {
1180 let path = req.uri().path();
1181
1182 if path == "/health" || path == "/.well-known/jwks.json" {
1184 return Ok(next.run(req).await);
1185 }
1186
1187 let jwt_config = match &state.jwt_config {
1189 Some(c) => c,
1190 None => return Ok(next.run(req).await),
1191 };
1192
1193 let auth_header = req
1195 .headers()
1196 .get("authorization")
1197 .and_then(|v| v.to_str().ok());
1198
1199 let token = match auth_header {
1200 Some(header) if header.starts_with("Bearer ") => &header[7..],
1201 _ => return Err(StatusCode::UNAUTHORIZED),
1202 };
1203
1204 match jwt::validate(token, jwt_config) {
1206 Ok(claims) => {
1207 tracing::debug!(sub = %claims.sub, scopes = %claims.scope, "JWT validated");
1208 req.extensions_mut().insert(claims);
1209 Ok(next.run(req).await)
1210 }
1211 Err(e) => {
1212 tracing::debug!(error = %e, "JWT validation failed");
1213 Err(StatusCode::UNAUTHORIZED)
1214 }
1215 }
1216}
1217
1218pub fn build_router(state: Arc<ProxyState>) -> Router {
1222 Router::new()
1223 .route("/call", post(handle_call))
1224 .route("/help", post(handle_help))
1225 .route("/mcp", post(handle_mcp))
1226 .route("/tools", get(handle_tools_list))
1227 .route("/tools/{name}", get(handle_tool_info))
1228 .route("/skills", get(handle_skills_list))
1229 .route("/skills/resolve", post(handle_skills_resolve))
1230 .route("/skills/bundle", post(handle_skills_bundle_batch))
1231 .route("/skills/{name}", get(handle_skill_detail))
1232 .route("/skills/{name}/bundle", get(handle_skill_bundle))
1233 .route("/health", get(handle_health))
1234 .route("/.well-known/jwks.json", get(handle_jwks))
1235 .layer(middleware::from_fn_with_state(
1236 state.clone(),
1237 auth_middleware,
1238 ))
1239 .with_state(state)
1240}
1241
1242pub async fn run(
1246 port: u16,
1247 bind_addr: Option<String>,
1248 ati_dir: PathBuf,
1249 _verbose: bool,
1250 env_keys: bool,
1251) -> Result<(), Box<dyn std::error::Error>> {
1252 let manifests_dir = ati_dir.join("manifests");
1254 let mut registry = ManifestRegistry::load(&manifests_dir)?;
1255 let provider_count = registry.list_providers().len();
1256
1257 let keyring_source;
1259 let keyring = if env_keys {
1260 let kr = Keyring::from_env();
1262 let key_names = kr.key_names();
1263 tracing::info!(
1264 count = key_names.len(),
1265 "loaded API keys from ATI_KEY_* env vars"
1266 );
1267 for name in &key_names {
1268 tracing::debug!(key = %name, "env key loaded");
1269 }
1270 keyring_source = "env-vars (ATI_KEY_*)";
1271 kr
1272 } else {
1273 let keyring_path = ati_dir.join("keyring.enc");
1275 if keyring_path.exists() {
1276 if let Ok(kr) = Keyring::load(&keyring_path) {
1277 keyring_source = "keyring.enc (sealed key)";
1278 kr
1279 } else if let Ok(kr) = Keyring::load_local(&keyring_path, &ati_dir) {
1280 keyring_source = "keyring.enc (persistent key)";
1281 kr
1282 } else {
1283 tracing::warn!("keyring.enc exists but could not be decrypted");
1284 keyring_source = "empty (decryption failed)";
1285 Keyring::empty()
1286 }
1287 } else {
1288 let creds_path = ati_dir.join("credentials");
1289 if creds_path.exists() {
1290 match Keyring::load_credentials(&creds_path) {
1291 Ok(kr) => {
1292 keyring_source = "credentials (plaintext)";
1293 kr
1294 }
1295 Err(e) => {
1296 tracing::warn!(error = %e, "failed to load credentials");
1297 keyring_source = "empty (credentials error)";
1298 Keyring::empty()
1299 }
1300 }
1301 } else {
1302 tracing::warn!("no keyring.enc or credentials found — running without API keys");
1303 tracing::warn!("tools requiring authentication will fail");
1304 keyring_source = "empty (no auth)";
1305 Keyring::empty()
1306 }
1307 }
1308 };
1309
1310 mcp_client::discover_all_mcp_tools(&mut registry, &keyring).await;
1313
1314 let tool_count = registry.list_public_tools().len();
1315
1316 let mcp_providers: Vec<(String, String)> = registry
1318 .list_mcp_providers()
1319 .iter()
1320 .map(|p| (p.name.clone(), p.mcp_transport_type().to_string()))
1321 .collect();
1322 let mcp_count = mcp_providers.len();
1323 let openapi_providers: Vec<String> = registry
1324 .list_openapi_providers()
1325 .iter()
1326 .map(|p| p.name.clone())
1327 .collect();
1328 let openapi_count = openapi_providers.len();
1329
1330 let skills_dir = ati_dir.join("skills");
1332 let mut skill_registry = SkillRegistry::load(&skills_dir).unwrap_or_else(|e| {
1333 tracing::warn!(error = %e, "failed to load skills");
1334 SkillRegistry::load(std::path::Path::new("/nonexistent-fallback")).unwrap()
1335 });
1336
1337 if let Ok(registry_url) = std::env::var("ATI_SKILL_REGISTRY") {
1339 if let Some(bucket) = registry_url.strip_prefix("gcs://") {
1340 let cred_key = "gcp_credentials";
1341 if let Some(cred_json) = keyring.get(cred_key) {
1342 match crate::core::gcs::GcsClient::new(bucket.to_string(), cred_json) {
1343 Ok(client) => match crate::core::gcs::GcsSkillSource::load(&client).await {
1344 Ok(gcs_source) => {
1345 let gcs_count = gcs_source.skill_count();
1346 skill_registry.merge(gcs_source);
1347 tracing::info!(
1348 bucket = %bucket,
1349 skills = gcs_count,
1350 "loaded skills from GCS registry"
1351 );
1352 }
1353 Err(e) => {
1354 tracing::warn!(error = %e, bucket = %bucket, "failed to load GCS skills");
1355 }
1356 },
1357 Err(e) => {
1358 tracing::warn!(error = %e, "failed to init GCS client");
1359 }
1360 }
1361 } else {
1362 tracing::warn!(
1363 key = %cred_key,
1364 "ATI_SKILL_REGISTRY set but GCS credentials not found in keyring"
1365 );
1366 }
1367 } else {
1368 tracing::warn!(
1369 url = %registry_url,
1370 "unsupported skill registry scheme (only gcs:// is supported)"
1371 );
1372 }
1373 }
1374
1375 let skill_count = skill_registry.skill_count();
1376
1377 let jwt_config = match jwt::config_from_env() {
1379 Ok(config) => config,
1380 Err(e) => {
1381 tracing::warn!(error = %e, "JWT config error");
1382 None
1383 }
1384 };
1385
1386 let auth_status = if jwt_config.is_some() {
1387 "JWT enabled"
1388 } else {
1389 "DISABLED (no JWT keys configured)"
1390 };
1391
1392 let jwks_json = jwt_config.as_ref().and_then(|config| {
1394 config
1395 .public_key_pem
1396 .as_ref()
1397 .and_then(|pem| jwt::public_key_to_jwks(pem, config.algorithm, "ati-proxy-1").ok())
1398 });
1399
1400 let state = Arc::new(ProxyState {
1401 registry,
1402 skill_registry,
1403 keyring,
1404 jwt_config,
1405 jwks_json,
1406 auth_cache: AuthCache::new(),
1407 });
1408
1409 let app = build_router(state);
1410
1411 let addr: SocketAddr = if let Some(ref bind) = bind_addr {
1412 format!("{bind}:{port}").parse()?
1413 } else {
1414 SocketAddr::from(([127, 0, 0, 1], port))
1415 };
1416
1417 tracing::info!(
1418 version = env!("CARGO_PKG_VERSION"),
1419 %addr,
1420 auth = auth_status,
1421 ati_dir = %ati_dir.display(),
1422 tools = tool_count,
1423 providers = provider_count,
1424 mcp = mcp_count,
1425 openapi = openapi_count,
1426 skills = skill_count,
1427 keyring = keyring_source,
1428 "ATI proxy server starting"
1429 );
1430 for (name, transport) in &mcp_providers {
1431 tracing::info!(provider = %name, transport = %transport, "MCP provider");
1432 }
1433 for name in &openapi_providers {
1434 tracing::info!(provider = %name, "OpenAPI provider");
1435 }
1436
1437 let listener = tokio::net::TcpListener::bind(addr).await?;
1438 axum::serve(listener, app).await?;
1439
1440 Ok(())
1441}
1442
1443fn write_proxy_audit(
1445 call_req: &CallRequest,
1446 agent_sub: &str,
1447 duration: std::time::Duration,
1448 error: Option<&str>,
1449) {
1450 let entry = crate::core::audit::AuditEntry {
1451 ts: chrono::Utc::now().to_rfc3339(),
1452 tool: call_req.tool_name.clone(),
1453 args: crate::core::audit::sanitize_args(&call_req.args),
1454 status: if error.is_some() {
1455 crate::core::audit::AuditStatus::Error
1456 } else {
1457 crate::core::audit::AuditStatus::Ok
1458 },
1459 duration_ms: duration.as_millis() as u64,
1460 agent_sub: agent_sub.to_string(),
1461 error: error.map(|s| s.to_string()),
1462 exit_code: None,
1463 };
1464 let _ = crate::core::audit::append(&entry);
1465}
1466
/// System prompt template for the `ati` help assistant.
///
/// Contains two literal placeholders, `{tools}` and `{skills_section}`. This is a raw
/// string constant, not a `format!` literal, so the braces are NOT interpolated by the
/// compiler. Presumably the caller substitutes them textually (e.g. with `str::replace`)
/// — confirm at the use site before changing placeholder names.
const HELP_SYSTEM_PROMPT: &str = r#"You are a helpful assistant for an AI agent that uses external tools via the `ati` CLI.

## Available Tools
{tools}

{skills_section}

Answer the agent's question naturally, like a knowledgeable colleague would. Keep it short but useful:

- Explain which tools to use and why, with `ati run` commands showing realistic parameter values
- If multiple steps are needed, walk through them briefly in order
- Mention important gotchas or parameter choices that matter
- If skills are relevant, suggest `ati skill show <name>` for the full methodology

Keep your answer concise — a few short paragraphs with embedded code blocks. Only recommend tools from the list above."#;
1484
1485fn build_tool_context(
1486 tools: &[(
1487 &crate::core::manifest::Provider,
1488 &crate::core::manifest::Tool,
1489 )],
1490) -> String {
1491 let mut summaries = Vec::new();
1492 for (provider, tool) in tools {
1493 let mut summary = if let Some(cat) = &provider.category {
1494 format!(
1495 "- **{}** (provider: {}, category: {}): {}",
1496 tool.name, provider.name, cat, tool.description
1497 )
1498 } else {
1499 format!(
1500 "- **{}** (provider: {}): {}",
1501 tool.name, provider.name, tool.description
1502 )
1503 };
1504 if !tool.tags.is_empty() {
1505 summary.push_str(&format!("\n Tags: {}", tool.tags.join(", ")));
1506 }
1507 if provider.is_cli() && tool.input_schema.is_none() {
1509 let cmd = provider.cli_command.as_deref().unwrap_or("?");
1510 summary.push_str(&format!(
1511 "\n Usage: `ati run {} -- <args>` (passthrough to `{}`)",
1512 tool.name, cmd
1513 ));
1514 } else if let Some(schema) = &tool.input_schema {
1515 if let Some(props) = schema.get("properties") {
1516 if let Some(obj) = props.as_object() {
1517 let params: Vec<String> = obj
1518 .iter()
1519 .filter(|(_, v)| {
1520 v.get("x-ati-param-location").is_none()
1521 || v.get("description").is_some()
1522 })
1523 .map(|(k, v)| {
1524 let type_str =
1525 v.get("type").and_then(|t| t.as_str()).unwrap_or("string");
1526 let desc = v.get("description").and_then(|d| d.as_str()).unwrap_or("");
1527 format!(" --{k} ({type_str}): {desc}")
1528 })
1529 .collect();
1530 if !params.is_empty() {
1531 summary.push_str("\n Parameters:\n");
1532 summary.push_str(¶ms.join("\n"));
1533 }
1534 }
1535 }
1536 }
1537 summaries.push(summary);
1538 }
1539 summaries.join("\n\n")
1540}
1541
1542fn build_scoped_prompt(
1546 scope_name: &str,
1547 registry: &ManifestRegistry,
1548 skills_section: &str,
1549) -> Option<String> {
1550 if let Some((provider, tool)) = registry.get_tool(scope_name) {
1552 let mut details = format!(
1553 "**Name**: `{}`\n**Provider**: {} (handler: {})\n**Description**: {}\n",
1554 tool.name, provider.name, provider.handler, tool.description
1555 );
1556 if let Some(cat) = &provider.category {
1557 details.push_str(&format!("**Category**: {}\n", cat));
1558 }
1559 if provider.is_cli() {
1560 let cmd = provider.cli_command.as_deref().unwrap_or("?");
1561 details.push_str(&format!(
1562 "\n**Usage**: `ati run {} -- <args>` (passthrough to `{}`)\n",
1563 tool.name, cmd
1564 ));
1565 } else if let Some(schema) = &tool.input_schema {
1566 if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
1567 let required: Vec<String> = schema
1568 .get("required")
1569 .and_then(|r| r.as_array())
1570 .map(|arr| {
1571 arr.iter()
1572 .filter_map(|v| v.as_str().map(|s| s.to_string()))
1573 .collect()
1574 })
1575 .unwrap_or_default();
1576 details.push_str("\n**Parameters**:\n");
1577 for (key, val) in props {
1578 let type_str = val.get("type").and_then(|t| t.as_str()).unwrap_or("string");
1579 let desc = val
1580 .get("description")
1581 .and_then(|d| d.as_str())
1582 .unwrap_or("");
1583 let req = if required.contains(key) {
1584 " **(required)**"
1585 } else {
1586 ""
1587 };
1588 details.push_str(&format!("- `--{key}` ({type_str}{req}): {desc}\n"));
1589 }
1590 }
1591 }
1592
1593 let prompt = format!(
1594 "You are an expert assistant for the `{}` tool, accessed via the `ati` CLI.\n\n\
1595 ## Tool Details\n{}\n\n{}\n\n\
1596 Answer the agent's question about this specific tool. Provide exact commands, explain flags and options, and give practical examples. Be concise and actionable.",
1597 tool.name, details, skills_section
1598 );
1599 return Some(prompt);
1600 }
1601
1602 if registry.has_provider(scope_name) {
1604 let tools = registry.tools_by_provider(scope_name);
1605 if tools.is_empty() {
1606 return None;
1607 }
1608 let tools_context = build_tool_context(&tools);
1609 let prompt = format!(
1610 "You are an expert assistant for the `{}` provider's tools, accessed via the `ati` CLI.\n\n\
1611 ## Tools in provider `{}`\n{}\n\n{}\n\n\
1612 Answer the agent's question about these tools. Provide exact `ati run` commands, explain parameters, and give practical examples. Be concise and actionable.",
1613 scope_name, scope_name, tools_context, skills_section
1614 );
1615 return Some(prompt);
1616 }
1617
1618 None
1619}