1use axum::{
8 body::Body,
9 extract::State,
10 http::{Request as HttpRequest, StatusCode},
11 middleware::{self, Next},
12 response::{IntoResponse, Response},
13 routing::{get, post},
14 Json, Router,
15};
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use std::collections::HashMap;
19use std::net::SocketAddr;
20use std::path::PathBuf;
21use std::sync::Arc;
22
23use crate::core::auth_generator::{AuthCache, GenContext};
24use crate::core::http;
25use crate::core::jwt::{self, JwtConfig, TokenClaims};
26use crate::core::keyring::Keyring;
27use crate::core::manifest::ManifestRegistry;
28use crate::core::mcp_client;
29use crate::core::response;
30use crate::core::scope::ScopeConfig;
31use crate::core::skill::{self, SkillRegistry};
32use crate::core::xai;
33
34pub struct ProxyState {
36 pub registry: ManifestRegistry,
37 pub skill_registry: SkillRegistry,
38 pub keyring: Keyring,
39 pub jwt_config: Option<JwtConfig>,
41 pub jwks_json: Option<Value>,
43 pub auth_cache: AuthCache,
45}
46
47#[derive(Debug, Deserialize)]
50pub struct CallRequest {
51 pub tool_name: String,
52 #[serde(default = "default_args")]
56 pub args: Value,
57 #[serde(default)]
60 pub raw_args: Option<Vec<String>>,
61}
62
63fn default_args() -> Value {
64 Value::Object(serde_json::Map::new())
65}
66
67impl CallRequest {
68 fn args_as_map(&self) -> HashMap<String, Value> {
72 match &self.args {
73 Value::Object(map) => map.iter().map(|(k, v)| (k.clone(), v.clone())).collect(),
74 _ => HashMap::new(),
75 }
76 }
77
78 fn args_as_positional(&self) -> Vec<String> {
81 if let Some(ref raw) = self.raw_args {
83 return raw.clone();
84 }
85 match &self.args {
86 Value::Array(arr) => arr
88 .iter()
89 .map(|v| match v {
90 Value::String(s) => s.clone(),
91 other => other.to_string(),
92 })
93 .collect(),
94 Value::String(s) => s.split_whitespace().map(String::from).collect(),
96 Value::Object(map) => {
98 if let Some(Value::Array(pos)) = map.get("_positional") {
99 return pos
100 .iter()
101 .map(|v| match v {
102 Value::String(s) => s.clone(),
103 other => other.to_string(),
104 })
105 .collect();
106 }
107 let mut result = Vec::new();
109 for (k, v) in map {
110 result.push(format!("--{k}"));
111 match v {
112 Value::String(s) => result.push(s.clone()),
113 Value::Bool(true) => {} other => result.push(other.to_string()),
115 }
116 }
117 result
118 }
119 _ => Vec::new(),
120 }
121 }
122}
123
124#[derive(Debug, Serialize)]
125pub struct CallResponse {
126 pub result: Value,
127 #[serde(skip_serializing_if = "Option::is_none")]
128 pub error: Option<String>,
129}
130
131#[derive(Debug, Deserialize)]
132pub struct HelpRequest {
133 pub query: String,
134 #[serde(default)]
135 pub tool: Option<String>,
136}
137
138#[derive(Debug, Serialize)]
139pub struct HelpResponse {
140 pub content: String,
141 #[serde(skip_serializing_if = "Option::is_none")]
142 pub error: Option<String>,
143}
144
145#[derive(Debug, Serialize)]
146pub struct HealthResponse {
147 pub status: String,
148 pub version: String,
149 pub tools: usize,
150 pub providers: usize,
151 pub skills: usize,
152 pub auth: String,
153}
154
155#[derive(Debug, Deserialize)]
158pub struct SkillsQuery {
159 #[serde(default)]
160 pub category: Option<String>,
161 #[serde(default)]
162 pub provider: Option<String>,
163 #[serde(default)]
164 pub tool: Option<String>,
165 #[serde(default)]
166 pub search: Option<String>,
167}
168
169#[derive(Debug, Deserialize)]
170pub struct SkillDetailQuery {
171 #[serde(default)]
172 pub meta: Option<bool>,
173 #[serde(default)]
174 pub refs: Option<bool>,
175}
176
177#[derive(Debug, Deserialize)]
178pub struct SkillResolveRequest {
179 pub scopes: Vec<String>,
180 #[serde(default)]
182 pub include_content: bool,
183}
184
185#[derive(Debug, Deserialize)]
186pub struct SkillBundleBatchRequest {
187 pub names: Vec<String>,
188}
189
190#[derive(Debug, Deserialize)]
193pub struct ToolsQuery {
194 #[serde(default)]
195 pub provider: Option<String>,
196 #[serde(default)]
197 pub search: Option<String>,
198}
199
200async fn handle_call(
203 State(state): State<Arc<ProxyState>>,
204 req: HttpRequest<Body>,
205) -> impl IntoResponse {
206 let claims = req.extensions().get::<TokenClaims>().cloned();
208
209 let body_bytes = match axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024).await {
211 Ok(b) => b,
212 Err(e) => {
213 return (
214 StatusCode::BAD_REQUEST,
215 Json(CallResponse {
216 result: Value::Null,
217 error: Some(format!("Failed to read request body: {e}")),
218 }),
219 );
220 }
221 };
222
223 let call_req: CallRequest = match serde_json::from_slice(&body_bytes) {
224 Ok(r) => r,
225 Err(e) => {
226 return (
227 StatusCode::UNPROCESSABLE_ENTITY,
228 Json(CallResponse {
229 result: Value::Null,
230 error: Some(format!("Invalid request: {e}")),
231 }),
232 );
233 }
234 };
235
236 tracing::debug!(
237 tool = %call_req.tool_name,
238 args = ?call_req.args,
239 "POST /call"
240 );
241
242 let (provider, tool) = match state.registry.get_tool(&call_req.tool_name) {
244 Some(pt) => pt,
245 None => {
246 return (
247 StatusCode::NOT_FOUND,
248 Json(CallResponse {
249 result: Value::Null,
250 error: Some(format!("Unknown tool: '{}'", call_req.tool_name)),
251 }),
252 );
253 }
254 };
255
256 if let Some(tool_scope) = &tool.scope {
258 let scopes = match &claims {
259 Some(c) => ScopeConfig::from_jwt(c),
260 None if state.jwt_config.is_none() => ScopeConfig::unrestricted(), None => {
262 return (
263 StatusCode::FORBIDDEN,
264 Json(CallResponse {
265 result: Value::Null,
266 error: Some("Authentication required — no JWT provided".into()),
267 }),
268 );
269 }
270 };
271
272 if let Err(e) = scopes.check_access(&call_req.tool_name, tool_scope) {
273 return (
274 StatusCode::FORBIDDEN,
275 Json(CallResponse {
276 result: Value::Null,
277 error: Some(format!("Access denied: {e}")),
278 }),
279 );
280 }
281 }
282
283 {
285 let scopes = match &claims {
286 Some(c) => ScopeConfig::from_jwt(c),
287 None => ScopeConfig::unrestricted(),
288 };
289 if let Some(ref rate_config) = scopes.rate_config {
290 if let Err(e) = crate::core::rate::check_and_record(&call_req.tool_name, rate_config) {
291 return (
292 StatusCode::TOO_MANY_REQUESTS,
293 Json(CallResponse {
294 result: Value::Null,
295 error: Some(format!("{e}")),
296 }),
297 );
298 }
299 }
300 }
301
302 let gen_ctx = GenContext {
304 jwt_sub: claims
305 .as_ref()
306 .map(|c| c.sub.clone())
307 .unwrap_or_else(|| "dev".into()),
308 jwt_scope: claims
309 .as_ref()
310 .map(|c| c.scope.clone())
311 .unwrap_or_else(|| "*".into()),
312 tool_name: call_req.tool_name.clone(),
313 timestamp: crate::core::jwt::now_secs(),
314 };
315
316 let agent_sub = claims.as_ref().map(|c| c.sub.clone()).unwrap_or_default();
318 let start = std::time::Instant::now();
319
320 let response = match provider.handler.as_str() {
321 "mcp" => {
322 let args_map = call_req.args_as_map();
323 match mcp_client::execute_with_gen(
324 provider,
325 &call_req.tool_name,
326 &args_map,
327 &state.keyring,
328 Some(&gen_ctx),
329 Some(&state.auth_cache),
330 )
331 .await
332 {
333 Ok(result) => (
334 StatusCode::OK,
335 Json(CallResponse {
336 result,
337 error: None,
338 }),
339 ),
340 Err(e) => (
341 StatusCode::BAD_GATEWAY,
342 Json(CallResponse {
343 result: Value::Null,
344 error: Some(format!("MCP error: {e}")),
345 }),
346 ),
347 }
348 }
349 "cli" => {
350 let positional = call_req.args_as_positional();
351 match crate::core::cli_executor::execute_with_gen(
352 provider,
353 &positional,
354 &state.keyring,
355 Some(&gen_ctx),
356 Some(&state.auth_cache),
357 )
358 .await
359 {
360 Ok(result) => (
361 StatusCode::OK,
362 Json(CallResponse {
363 result,
364 error: None,
365 }),
366 ),
367 Err(e) => (
368 StatusCode::BAD_GATEWAY,
369 Json(CallResponse {
370 result: Value::Null,
371 error: Some(format!("CLI error: {e}")),
372 }),
373 ),
374 }
375 }
376 _ => {
377 let args_map = call_req.args_as_map();
378 let raw_response = match match provider.handler.as_str() {
379 "xai" => xai::execute_xai_tool(provider, tool, &args_map, &state.keyring).await,
380 _ => {
381 http::execute_tool_with_gen(
382 provider,
383 tool,
384 &args_map,
385 &state.keyring,
386 Some(&gen_ctx),
387 Some(&state.auth_cache),
388 )
389 .await
390 }
391 } {
392 Ok(resp) => resp,
393 Err(e) => {
394 let duration = start.elapsed();
395 write_proxy_audit(&call_req, &agent_sub, duration, Some(&e.to_string()));
396 return (
397 StatusCode::BAD_GATEWAY,
398 Json(CallResponse {
399 result: Value::Null,
400 error: Some(format!("Upstream API error: {e}")),
401 }),
402 );
403 }
404 };
405
406 let processed = match response::process_response(&raw_response, tool.response.as_ref())
407 {
408 Ok(p) => p,
409 Err(e) => {
410 let duration = start.elapsed();
411 write_proxy_audit(&call_req, &agent_sub, duration, Some(&e.to_string()));
412 return (
413 StatusCode::INTERNAL_SERVER_ERROR,
414 Json(CallResponse {
415 result: raw_response,
416 error: Some(format!("Response processing error: {e}")),
417 }),
418 );
419 }
420 };
421
422 (
423 StatusCode::OK,
424 Json(CallResponse {
425 result: processed,
426 error: None,
427 }),
428 )
429 }
430 };
431
432 let duration = start.elapsed();
433 let error_msg = response.1.error.as_deref();
434 write_proxy_audit(&call_req, &agent_sub, duration, error_msg);
435
436 response
437}
438
439async fn handle_help(
440 State(state): State<Arc<ProxyState>>,
441 Json(req): Json<HelpRequest>,
442) -> impl IntoResponse {
443 tracing::debug!(query = %req.query, tool = ?req.tool, "POST /help");
444
445 let (llm_provider, llm_tool) = match state.registry.get_tool("_chat_completion") {
446 Some(pt) => pt,
447 None => {
448 return (
449 StatusCode::SERVICE_UNAVAILABLE,
450 Json(HelpResponse {
451 content: String::new(),
452 error: Some("No _llm.toml manifest found. Proxy help requires a configured LLM provider.".into()),
453 }),
454 );
455 }
456 };
457
458 let api_key = match llm_provider
459 .auth_key_name
460 .as_deref()
461 .and_then(|k| state.keyring.get(k))
462 {
463 Some(key) => key.to_string(),
464 None => {
465 return (
466 StatusCode::SERVICE_UNAVAILABLE,
467 Json(HelpResponse {
468 content: String::new(),
469 error: Some("LLM API key not found in keyring".into()),
470 }),
471 );
472 }
473 };
474
475 let scopes = ScopeConfig::unrestricted();
476 let resolved_skills = skill::resolve_skills(&state.skill_registry, &state.registry, &scopes);
477 let skills_section = if resolved_skills.is_empty() {
478 String::new()
479 } else {
480 format!(
481 "## Available Skills (methodology guides)\n{}",
482 skill::build_skill_context(&resolved_skills)
483 )
484 };
485
486 let system_prompt = if let Some(ref tool_name) = req.tool {
488 match build_scoped_prompt(tool_name, &state.registry, &skills_section) {
490 Some(prompt) => prompt,
491 None => {
492 tracing::debug!(scope = %tool_name, "scope not found, falling back to unscoped");
494 let all_tools = state.registry.list_public_tools();
495 let tools_context = build_tool_context(&all_tools);
496 HELP_SYSTEM_PROMPT
497 .replace("{tools}", &tools_context)
498 .replace("{skills_section}", &skills_section)
499 }
500 }
501 } else {
502 let all_tools = state.registry.list_public_tools();
503 let tools_context = build_tool_context(&all_tools);
504 HELP_SYSTEM_PROMPT
505 .replace("{tools}", &tools_context)
506 .replace("{skills_section}", &skills_section)
507 };
508
509 let request_body = serde_json::json!({
510 "model": "zai-glm-4.7",
511 "messages": [
512 {"role": "system", "content": system_prompt},
513 {"role": "user", "content": req.query}
514 ],
515 "max_completion_tokens": 1536,
516 "temperature": 0.3
517 });
518
519 let client = reqwest::Client::new();
520 let url = format!(
521 "{}{}",
522 llm_provider.base_url.trim_end_matches('/'),
523 llm_tool.endpoint
524 );
525
526 let response = match client
527 .post(&url)
528 .bearer_auth(&api_key)
529 .json(&request_body)
530 .send()
531 .await
532 {
533 Ok(r) => r,
534 Err(e) => {
535 return (
536 StatusCode::BAD_GATEWAY,
537 Json(HelpResponse {
538 content: String::new(),
539 error: Some(format!("LLM request failed: {e}")),
540 }),
541 );
542 }
543 };
544
545 if !response.status().is_success() {
546 let status = response.status();
547 let body = response.text().await.unwrap_or_default();
548 return (
549 StatusCode::BAD_GATEWAY,
550 Json(HelpResponse {
551 content: String::new(),
552 error: Some(format!("LLM API error ({status}): {body}")),
553 }),
554 );
555 }
556
557 let body: Value = match response.json().await {
558 Ok(b) => b,
559 Err(e) => {
560 return (
561 StatusCode::INTERNAL_SERVER_ERROR,
562 Json(HelpResponse {
563 content: String::new(),
564 error: Some(format!("Failed to parse LLM response: {e}")),
565 }),
566 );
567 }
568 };
569
570 let content = body
571 .pointer("/choices/0/message/content")
572 .and_then(|c| c.as_str())
573 .unwrap_or("No response from LLM")
574 .to_string();
575
576 (
577 StatusCode::OK,
578 Json(HelpResponse {
579 content,
580 error: None,
581 }),
582 )
583}
584
585async fn handle_health(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
586 let auth = if state.jwt_config.is_some() {
587 "jwt"
588 } else {
589 "disabled"
590 };
591
592 Json(HealthResponse {
593 status: "ok".into(),
594 version: env!("CARGO_PKG_VERSION").into(),
595 tools: state.registry.list_public_tools().len(),
596 providers: state.registry.list_providers().len(),
597 skills: state.skill_registry.skill_count(),
598 auth: auth.into(),
599 })
600}
601
602async fn handle_jwks(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
604 match &state.jwks_json {
605 Some(jwks) => (StatusCode::OK, Json(jwks.clone())),
606 None => (
607 StatusCode::NOT_FOUND,
608 Json(serde_json::json!({"error": "JWKS not configured"})),
609 ),
610 }
611}
612
613async fn handle_mcp(
618 State(state): State<Arc<ProxyState>>,
619 Json(msg): Json<Value>,
620) -> impl IntoResponse {
621 let method = msg.get("method").and_then(|m| m.as_str()).unwrap_or("");
622 let id = msg.get("id").cloned();
623
624 tracing::debug!(%method, "POST /mcp");
625
626 match method {
627 "initialize" => {
628 let result = serde_json::json!({
629 "protocolVersion": "2025-03-26",
630 "capabilities": {
631 "tools": { "listChanged": false }
632 },
633 "serverInfo": {
634 "name": "ati-proxy",
635 "version": env!("CARGO_PKG_VERSION")
636 }
637 });
638 jsonrpc_success(id, result)
639 }
640
641 "notifications/initialized" => (StatusCode::ACCEPTED, Json(Value::Null)),
642
643 "tools/list" => {
644 let all_tools = state.registry.list_public_tools();
645 let mcp_tools: Vec<Value> = all_tools
646 .iter()
647 .map(|(_provider, tool)| {
648 serde_json::json!({
649 "name": tool.name,
650 "description": tool.description,
651 "inputSchema": tool.input_schema.clone().unwrap_or(serde_json::json!({
652 "type": "object",
653 "properties": {}
654 }))
655 })
656 })
657 .collect();
658
659 let result = serde_json::json!({
660 "tools": mcp_tools,
661 });
662 jsonrpc_success(id, result)
663 }
664
665 "tools/call" => {
666 let params = msg.get("params").cloned().unwrap_or(Value::Null);
667 let tool_name = params.get("name").and_then(|n| n.as_str()).unwrap_or("");
668 let arguments: HashMap<String, Value> = params
669 .get("arguments")
670 .and_then(|a| serde_json::from_value(a.clone()).ok())
671 .unwrap_or_default();
672
673 if tool_name.is_empty() {
674 return jsonrpc_error(id, -32602, "Missing tool name in params.name");
675 }
676
677 let (provider, _tool) = match state.registry.get_tool(tool_name) {
678 Some(pt) => pt,
679 None => {
680 return jsonrpc_error(id, -32602, &format!("Unknown tool: '{tool_name}'"));
681 }
682 };
683
684 tracing::debug!(%tool_name, provider = %provider.name, "MCP tools/call");
685
686 let mcp_gen_ctx = GenContext {
687 jwt_sub: "dev".into(),
688 jwt_scope: "*".into(),
689 tool_name: tool_name.to_string(),
690 timestamp: crate::core::jwt::now_secs(),
691 };
692
693 let result = if provider.is_mcp() {
694 mcp_client::execute_with_gen(
695 provider,
696 tool_name,
697 &arguments,
698 &state.keyring,
699 Some(&mcp_gen_ctx),
700 Some(&state.auth_cache),
701 )
702 .await
703 } else if provider.is_cli() {
704 let raw: Vec<String> = arguments
706 .iter()
707 .flat_map(|(k, v)| {
708 let val = match v {
709 Value::String(s) => s.clone(),
710 other => other.to_string(),
711 };
712 vec![format!("--{k}"), val]
713 })
714 .collect();
715 crate::core::cli_executor::execute_with_gen(
716 provider,
717 &raw,
718 &state.keyring,
719 Some(&mcp_gen_ctx),
720 Some(&state.auth_cache),
721 )
722 .await
723 .map_err(|e| mcp_client::McpError::Transport(e.to_string()))
724 } else {
725 match match provider.handler.as_str() {
726 "xai" => {
727 xai::execute_xai_tool(provider, _tool, &arguments, &state.keyring).await
728 }
729 _ => {
730 http::execute_tool_with_gen(
731 provider,
732 _tool,
733 &arguments,
734 &state.keyring,
735 Some(&mcp_gen_ctx),
736 Some(&state.auth_cache),
737 )
738 .await
739 }
740 } {
741 Ok(val) => Ok(val),
742 Err(e) => Err(mcp_client::McpError::Transport(e.to_string())),
743 }
744 };
745
746 match result {
747 Ok(value) => {
748 let text = match &value {
749 Value::String(s) => s.clone(),
750 other => serde_json::to_string_pretty(other).unwrap_or_default(),
751 };
752 let mcp_result = serde_json::json!({
753 "content": [{"type": "text", "text": text}],
754 "isError": false,
755 });
756 jsonrpc_success(id, mcp_result)
757 }
758 Err(e) => {
759 let mcp_result = serde_json::json!({
760 "content": [{"type": "text", "text": format!("Error: {e}")}],
761 "isError": true,
762 });
763 jsonrpc_success(id, mcp_result)
764 }
765 }
766 }
767
768 _ => jsonrpc_error(id, -32601, &format!("Method not found: '{method}'")),
769 }
770}
771
772fn jsonrpc_success(id: Option<Value>, result: Value) -> (StatusCode, Json<Value>) {
773 (
774 StatusCode::OK,
775 Json(serde_json::json!({
776 "jsonrpc": "2.0",
777 "id": id,
778 "result": result,
779 })),
780 )
781}
782
783fn jsonrpc_error(id: Option<Value>, code: i64, message: &str) -> (StatusCode, Json<Value>) {
784 (
785 StatusCode::OK,
786 Json(serde_json::json!({
787 "jsonrpc": "2.0",
788 "id": id,
789 "error": {
790 "code": code,
791 "message": message,
792 }
793 })),
794 )
795}
796
797async fn handle_tools_list(
803 State(state): State<Arc<ProxyState>>,
804 axum::extract::Query(query): axum::extract::Query<ToolsQuery>,
805) -> impl IntoResponse {
806 tracing::debug!(
807 provider = ?query.provider,
808 search = ?query.search,
809 "GET /tools"
810 );
811
812 let all_tools = state.registry.list_public_tools();
813
814 let tools: Vec<Value> = all_tools
815 .iter()
816 .filter(|(provider, tool)| {
817 if let Some(ref p) = query.provider {
818 if provider.name != *p {
819 return false;
820 }
821 }
822 if let Some(ref q) = query.search {
823 let q = q.to_lowercase();
824 let name_match = tool.name.to_lowercase().contains(&q);
825 let desc_match = tool.description.to_lowercase().contains(&q);
826 let tag_match = tool.tags.iter().any(|t| t.to_lowercase().contains(&q));
827 if !name_match && !desc_match && !tag_match {
828 return false;
829 }
830 }
831 true
832 })
833 .map(|(provider, tool)| {
834 serde_json::json!({
835 "name": tool.name,
836 "description": tool.description,
837 "provider": provider.name,
838 "method": format!("{:?}", tool.method),
839 "tags": tool.tags,
840 "input_schema": tool.input_schema,
841 })
842 })
843 .collect();
844
845 (StatusCode::OK, Json(Value::Array(tools)))
846}
847
848async fn handle_tool_info(
850 State(state): State<Arc<ProxyState>>,
851 axum::extract::Path(name): axum::extract::Path<String>,
852) -> impl IntoResponse {
853 tracing::debug!(tool = %name, "GET /tools/:name");
854
855 match state.registry.get_tool(&name) {
856 Some((provider, tool)) => (
857 StatusCode::OK,
858 Json(serde_json::json!({
859 "name": tool.name,
860 "description": tool.description,
861 "provider": provider.name,
862 "method": format!("{:?}", tool.method),
863 "endpoint": tool.endpoint,
864 "tags": tool.tags,
865 "hint": tool.hint,
866 "input_schema": tool.input_schema,
867 "scope": tool.scope,
868 })),
869 ),
870 None => (
871 StatusCode::NOT_FOUND,
872 Json(serde_json::json!({"error": format!("Tool '{name}' not found")})),
873 ),
874 }
875}
876
877async fn handle_skills_list(
882 State(state): State<Arc<ProxyState>>,
883 axum::extract::Query(query): axum::extract::Query<SkillsQuery>,
884) -> impl IntoResponse {
885 tracing::debug!(
886 category = ?query.category,
887 provider = ?query.provider,
888 tool = ?query.tool,
889 search = ?query.search,
890 "GET /skills"
891 );
892
893 let skills: Vec<&skill::SkillMeta> = if let Some(search_query) = &query.search {
894 state.skill_registry.search(search_query)
895 } else if let Some(cat) = &query.category {
896 state.skill_registry.skills_for_category(cat)
897 } else if let Some(prov) = &query.provider {
898 state.skill_registry.skills_for_provider(prov)
899 } else if let Some(t) = &query.tool {
900 state.skill_registry.skills_for_tool(t)
901 } else {
902 state.skill_registry.list_skills().iter().collect()
903 };
904
905 let json: Vec<Value> = skills
906 .iter()
907 .map(|s| {
908 serde_json::json!({
909 "name": s.name,
910 "version": s.version,
911 "description": s.description,
912 "tools": s.tools,
913 "providers": s.providers,
914 "categories": s.categories,
915 "hint": s.hint,
916 })
917 })
918 .collect();
919
920 (StatusCode::OK, Json(Value::Array(json)))
921}
922
923async fn handle_skill_detail(
924 State(state): State<Arc<ProxyState>>,
925 axum::extract::Path(name): axum::extract::Path<String>,
926 axum::extract::Query(query): axum::extract::Query<SkillDetailQuery>,
927) -> impl IntoResponse {
928 tracing::debug!(%name, meta = ?query.meta, refs = ?query.refs, "GET /skills/:name");
929
930 let skill_meta = match state.skill_registry.get_skill(&name) {
931 Some(s) => s,
932 None => {
933 return (
934 StatusCode::NOT_FOUND,
935 Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
936 );
937 }
938 };
939
940 if query.meta.unwrap_or(false) {
941 return (
942 StatusCode::OK,
943 Json(serde_json::json!({
944 "name": skill_meta.name,
945 "version": skill_meta.version,
946 "description": skill_meta.description,
947 "author": skill_meta.author,
948 "tools": skill_meta.tools,
949 "providers": skill_meta.providers,
950 "categories": skill_meta.categories,
951 "keywords": skill_meta.keywords,
952 "hint": skill_meta.hint,
953 "depends_on": skill_meta.depends_on,
954 "suggests": skill_meta.suggests,
955 "license": skill_meta.license,
956 "compatibility": skill_meta.compatibility,
957 "allowed_tools": skill_meta.allowed_tools,
958 "format": skill_meta.format,
959 })),
960 );
961 }
962
963 let content = match state.skill_registry.read_content(&name) {
964 Ok(c) => c,
965 Err(e) => {
966 return (
967 StatusCode::INTERNAL_SERVER_ERROR,
968 Json(serde_json::json!({"error": format!("Failed to read skill: {e}")})),
969 );
970 }
971 };
972
973 let mut response = serde_json::json!({
974 "name": skill_meta.name,
975 "version": skill_meta.version,
976 "description": skill_meta.description,
977 "content": content,
978 });
979
980 if query.refs.unwrap_or(false) {
981 if let Ok(refs) = state.skill_registry.list_references(&name) {
982 response["references"] = serde_json::json!(refs);
983 }
984 }
985
986 (StatusCode::OK, Json(response))
987}
988
989async fn handle_skill_bundle(
993 State(state): State<Arc<ProxyState>>,
994 axum::extract::Path(name): axum::extract::Path<String>,
995) -> impl IntoResponse {
996 tracing::debug!(skill = %name, "GET /skills/:name/bundle");
997
998 let files = match state.skill_registry.bundle_files(&name) {
999 Ok(f) => f,
1000 Err(_) => {
1001 return (
1002 StatusCode::NOT_FOUND,
1003 Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
1004 );
1005 }
1006 };
1007
1008 let mut file_map = serde_json::Map::new();
1010 for (path, data) in &files {
1011 match std::str::from_utf8(data) {
1012 Ok(text) => {
1013 file_map.insert(path.clone(), Value::String(text.to_string()));
1014 }
1015 Err(_) => {
1016 use base64::Engine;
1018 let encoded = base64::engine::general_purpose::STANDARD.encode(data);
1019 file_map.insert(path.clone(), serde_json::json!({"base64": encoded}));
1020 }
1021 }
1022 }
1023
1024 (
1025 StatusCode::OK,
1026 Json(serde_json::json!({
1027 "name": name,
1028 "files": file_map,
1029 })),
1030 )
1031}
1032
1033async fn handle_skills_bundle_batch(
1037 State(state): State<Arc<ProxyState>>,
1038 Json(req): Json<SkillBundleBatchRequest>,
1039) -> impl IntoResponse {
1040 const MAX_BATCH: usize = 50;
1041 if req.names.len() > MAX_BATCH {
1042 return (
1043 StatusCode::BAD_REQUEST,
1044 Json(
1045 serde_json::json!({"error": format!("batch size {} exceeds limit of {MAX_BATCH}", req.names.len())}),
1046 ),
1047 );
1048 }
1049
1050 tracing::debug!(names = ?req.names, "POST /skills/bundle");
1051
1052 let mut result = serde_json::Map::new();
1053 let mut missing: Vec<String> = Vec::new();
1054
1055 for name in &req.names {
1056 let files = match state.skill_registry.bundle_files(name) {
1057 Ok(f) => f,
1058 Err(_) => {
1059 missing.push(name.clone());
1060 continue;
1061 }
1062 };
1063
1064 let mut file_map = serde_json::Map::new();
1065 for (path, data) in &files {
1066 match std::str::from_utf8(data) {
1067 Ok(text) => {
1068 file_map.insert(path.clone(), Value::String(text.to_string()));
1069 }
1070 Err(_) => {
1071 use base64::Engine;
1072 let encoded = base64::engine::general_purpose::STANDARD.encode(data);
1073 file_map.insert(path.clone(), serde_json::json!({"base64": encoded}));
1074 }
1075 }
1076 }
1077
1078 result.insert(name.clone(), serde_json::json!({ "files": file_map }));
1079 }
1080
1081 (
1082 StatusCode::OK,
1083 Json(serde_json::json!({ "skills": result, "missing": missing })),
1084 )
1085}
1086
1087async fn handle_skills_resolve(
1088 State(state): State<Arc<ProxyState>>,
1089 Json(req): Json<SkillResolveRequest>,
1090) -> impl IntoResponse {
1091 tracing::debug!(scopes = ?req.scopes, include_content = req.include_content, "POST /skills/resolve");
1092
1093 let include_content = req.include_content;
1094 let scopes = ScopeConfig {
1095 scopes: req.scopes,
1096 sub: String::new(),
1097 expires_at: 0,
1098 rate_config: None,
1099 };
1100
1101 let resolved = skill::resolve_skills(&state.skill_registry, &state.registry, &scopes);
1102
1103 let json: Vec<Value> = resolved
1104 .iter()
1105 .map(|s| {
1106 let mut entry = serde_json::json!({
1107 "name": s.name,
1108 "version": s.version,
1109 "description": s.description,
1110 "tools": s.tools,
1111 "providers": s.providers,
1112 "categories": s.categories,
1113 });
1114 if include_content {
1115 if let Ok(content) = state.skill_registry.read_content(&s.name) {
1116 entry["content"] = Value::String(content);
1117 }
1118 }
1119 entry
1120 })
1121 .collect();
1122
1123 (StatusCode::OK, Json(Value::Array(json)))
1124}
1125
1126async fn auth_middleware(
1134 State(state): State<Arc<ProxyState>>,
1135 mut req: HttpRequest<Body>,
1136 next: Next,
1137) -> Result<Response, StatusCode> {
1138 let path = req.uri().path();
1139
1140 if path == "/health" || path == "/.well-known/jwks.json" {
1142 return Ok(next.run(req).await);
1143 }
1144
1145 let jwt_config = match &state.jwt_config {
1147 Some(c) => c,
1148 None => return Ok(next.run(req).await),
1149 };
1150
1151 let auth_header = req
1153 .headers()
1154 .get("authorization")
1155 .and_then(|v| v.to_str().ok());
1156
1157 let token = match auth_header {
1158 Some(header) if header.starts_with("Bearer ") => &header[7..],
1159 _ => return Err(StatusCode::UNAUTHORIZED),
1160 };
1161
1162 match jwt::validate(token, jwt_config) {
1164 Ok(claims) => {
1165 tracing::debug!(sub = %claims.sub, scopes = %claims.scope, "JWT validated");
1166 req.extensions_mut().insert(claims);
1167 Ok(next.run(req).await)
1168 }
1169 Err(e) => {
1170 tracing::debug!(error = %e, "JWT validation failed");
1171 Err(StatusCode::UNAUTHORIZED)
1172 }
1173 }
1174}
1175
1176pub fn build_router(state: Arc<ProxyState>) -> Router {
1180 Router::new()
1181 .route("/call", post(handle_call))
1182 .route("/help", post(handle_help))
1183 .route("/mcp", post(handle_mcp))
1184 .route("/tools", get(handle_tools_list))
1185 .route("/tools/{name}", get(handle_tool_info))
1186 .route("/skills", get(handle_skills_list))
1187 .route("/skills/resolve", post(handle_skills_resolve))
1188 .route("/skills/bundle", post(handle_skills_bundle_batch))
1189 .route("/skills/{name}", get(handle_skill_detail))
1190 .route("/skills/{name}/bundle", get(handle_skill_bundle))
1191 .route("/health", get(handle_health))
1192 .route("/.well-known/jwks.json", get(handle_jwks))
1193 .layer(middleware::from_fn_with_state(
1194 state.clone(),
1195 auth_middleware,
1196 ))
1197 .with_state(state)
1198}
1199
1200pub async fn run(
1204 port: u16,
1205 bind_addr: Option<String>,
1206 ati_dir: PathBuf,
1207 _verbose: bool,
1208 env_keys: bool,
1209) -> Result<(), Box<dyn std::error::Error>> {
1210 let manifests_dir = ati_dir.join("manifests");
1212 let registry = ManifestRegistry::load(&manifests_dir)?;
1213
1214 let tool_count = registry.list_public_tools().len();
1215 let provider_count = registry.list_providers().len();
1216
1217 let keyring_source;
1219 let keyring = if env_keys {
1220 let kr = Keyring::from_env();
1222 let key_names = kr.key_names();
1223 tracing::info!(
1224 count = key_names.len(),
1225 "loaded API keys from ATI_KEY_* env vars"
1226 );
1227 for name in &key_names {
1228 tracing::debug!(key = %name, "env key loaded");
1229 }
1230 keyring_source = "env-vars (ATI_KEY_*)";
1231 kr
1232 } else {
1233 let keyring_path = ati_dir.join("keyring.enc");
1235 if keyring_path.exists() {
1236 if let Ok(kr) = Keyring::load(&keyring_path) {
1237 keyring_source = "keyring.enc (sealed key)";
1238 kr
1239 } else if let Ok(kr) = Keyring::load_local(&keyring_path, &ati_dir) {
1240 keyring_source = "keyring.enc (persistent key)";
1241 kr
1242 } else {
1243 tracing::warn!("keyring.enc exists but could not be decrypted");
1244 keyring_source = "empty (decryption failed)";
1245 Keyring::empty()
1246 }
1247 } else {
1248 let creds_path = ati_dir.join("credentials");
1249 if creds_path.exists() {
1250 match Keyring::load_credentials(&creds_path) {
1251 Ok(kr) => {
1252 keyring_source = "credentials (plaintext)";
1253 kr
1254 }
1255 Err(e) => {
1256 tracing::warn!(error = %e, "failed to load credentials");
1257 keyring_source = "empty (credentials error)";
1258 Keyring::empty()
1259 }
1260 }
1261 } else {
1262 tracing::warn!("no keyring.enc or credentials found — running without API keys");
1263 tracing::warn!("tools requiring authentication will fail");
1264 keyring_source = "empty (no auth)";
1265 Keyring::empty()
1266 }
1267 }
1268 };
1269
1270 let mcp_providers: Vec<(String, String)> = registry
1272 .list_mcp_providers()
1273 .iter()
1274 .map(|p| (p.name.clone(), p.mcp_transport_type().to_string()))
1275 .collect();
1276 let mcp_count = mcp_providers.len();
1277 let openapi_providers: Vec<String> = registry
1278 .list_openapi_providers()
1279 .iter()
1280 .map(|p| p.name.clone())
1281 .collect();
1282 let openapi_count = openapi_providers.len();
1283
1284 let skills_dir = ati_dir.join("skills");
1286 let mut skill_registry = SkillRegistry::load(&skills_dir).unwrap_or_else(|e| {
1287 tracing::warn!(error = %e, "failed to load skills");
1288 SkillRegistry::load(std::path::Path::new("/nonexistent-fallback")).unwrap()
1289 });
1290
1291 if let Ok(registry_url) = std::env::var("ATI_SKILL_REGISTRY") {
1293 if let Some(bucket) = registry_url.strip_prefix("gcs://") {
1294 let cred_key = "gcp_credentials";
1295 if let Some(cred_json) = keyring.get(cred_key) {
1296 match crate::core::gcs::GcsClient::new(bucket.to_string(), cred_json) {
1297 Ok(client) => match crate::core::gcs::GcsSkillSource::load(&client).await {
1298 Ok(gcs_source) => {
1299 let gcs_count = gcs_source.skill_count();
1300 skill_registry.merge(gcs_source);
1301 tracing::info!(
1302 bucket = %bucket,
1303 skills = gcs_count,
1304 "loaded skills from GCS registry"
1305 );
1306 }
1307 Err(e) => {
1308 tracing::warn!(error = %e, bucket = %bucket, "failed to load GCS skills");
1309 }
1310 },
1311 Err(e) => {
1312 tracing::warn!(error = %e, "failed to init GCS client");
1313 }
1314 }
1315 } else {
1316 tracing::warn!(
1317 key = %cred_key,
1318 "ATI_SKILL_REGISTRY set but GCS credentials not found in keyring"
1319 );
1320 }
1321 } else {
1322 tracing::warn!(
1323 url = %registry_url,
1324 "unsupported skill registry scheme (only gcs:// is supported)"
1325 );
1326 }
1327 }
1328
1329 let skill_count = skill_registry.skill_count();
1330
1331 let jwt_config = match jwt::config_from_env() {
1333 Ok(config) => config,
1334 Err(e) => {
1335 tracing::warn!(error = %e, "JWT config error");
1336 None
1337 }
1338 };
1339
1340 let auth_status = if jwt_config.is_some() {
1341 "JWT enabled"
1342 } else {
1343 "DISABLED (no JWT keys configured)"
1344 };
1345
1346 let jwks_json = jwt_config.as_ref().and_then(|config| {
1348 config
1349 .public_key_pem
1350 .as_ref()
1351 .and_then(|pem| jwt::public_key_to_jwks(pem, config.algorithm, "ati-proxy-1").ok())
1352 });
1353
1354 let state = Arc::new(ProxyState {
1355 registry,
1356 skill_registry,
1357 keyring,
1358 jwt_config,
1359 jwks_json,
1360 auth_cache: AuthCache::new(),
1361 });
1362
1363 let app = build_router(state);
1364
1365 let addr: SocketAddr = if let Some(ref bind) = bind_addr {
1366 format!("{bind}:{port}").parse()?
1367 } else {
1368 SocketAddr::from(([127, 0, 0, 1], port))
1369 };
1370
1371 tracing::info!(
1372 version = env!("CARGO_PKG_VERSION"),
1373 %addr,
1374 auth = auth_status,
1375 ati_dir = %ati_dir.display(),
1376 tools = tool_count,
1377 providers = provider_count,
1378 mcp = mcp_count,
1379 openapi = openapi_count,
1380 skills = skill_count,
1381 keyring = keyring_source,
1382 "ATI proxy server starting"
1383 );
1384 for (name, transport) in &mcp_providers {
1385 tracing::info!(provider = %name, transport = %transport, "MCP provider");
1386 }
1387 for name in &openapi_providers {
1388 tracing::info!(provider = %name, "OpenAPI provider");
1389 }
1390
1391 let listener = tokio::net::TcpListener::bind(addr).await?;
1392 axum::serve(listener, app).await?;
1393
1394 Ok(())
1395}
1396
1397fn write_proxy_audit(
1399 call_req: &CallRequest,
1400 agent_sub: &str,
1401 duration: std::time::Duration,
1402 error: Option<&str>,
1403) {
1404 let entry = crate::core::audit::AuditEntry {
1405 ts: chrono::Utc::now().to_rfc3339(),
1406 tool: call_req.tool_name.clone(),
1407 args: crate::core::audit::sanitize_args(&call_req.args),
1408 status: if error.is_some() {
1409 crate::core::audit::AuditStatus::Error
1410 } else {
1411 crate::core::audit::AuditStatus::Ok
1412 },
1413 duration_ms: duration.as_millis() as u64,
1414 agent_sub: agent_sub.to_string(),
1415 error: error.map(|s| s.to_string()),
1416 exit_code: None,
1417 };
1418 let _ = crate::core::audit::append(&entry);
1419}
1420
/// System prompt template for the general (unscoped) help assistant.
///
/// The literal markers `{tools}` and `{skills_section}` are placeholders. This is a plain `const`
/// raw string, not a `format!` template, so they must be filled in before the prompt is sent.
/// NOTE(review): presumably the caller substitutes them via string replacement, with the output of
/// `build_tool_context` and a skills summary respectively; confirm at the call site.
const HELP_SYSTEM_PROMPT: &str = r#"You are a helpful assistant for an AI agent that uses external tools via the `ati` CLI.

## Available Tools
{tools}

{skills_section}

Answer the agent's question naturally, like a knowledgeable colleague would. Keep it short but useful:

- Explain which tools to use and why, with `ati run` commands showing realistic parameter values
- If multiple steps are needed, walk through them briefly in order
- Mention important gotchas or parameter choices that matter
- If skills are relevant, suggest `ati skill show <name>` for the full methodology

Keep your answer concise — a few short paragraphs with embedded code blocks. Only recommend tools from the list above."#;
1438
1439fn build_tool_context(
1440 tools: &[(
1441 &crate::core::manifest::Provider,
1442 &crate::core::manifest::Tool,
1443 )],
1444) -> String {
1445 let mut summaries = Vec::new();
1446 for (provider, tool) in tools {
1447 let mut summary = if let Some(cat) = &provider.category {
1448 format!(
1449 "- **{}** (provider: {}, category: {}): {}",
1450 tool.name, provider.name, cat, tool.description
1451 )
1452 } else {
1453 format!(
1454 "- **{}** (provider: {}): {}",
1455 tool.name, provider.name, tool.description
1456 )
1457 };
1458 if !tool.tags.is_empty() {
1459 summary.push_str(&format!("\n Tags: {}", tool.tags.join(", ")));
1460 }
1461 if provider.is_cli() && tool.input_schema.is_none() {
1463 let cmd = provider.cli_command.as_deref().unwrap_or("?");
1464 summary.push_str(&format!(
1465 "\n Usage: `ati run {} -- <args>` (passthrough to `{}`)",
1466 tool.name, cmd
1467 ));
1468 } else if let Some(schema) = &tool.input_schema {
1469 if let Some(props) = schema.get("properties") {
1470 if let Some(obj) = props.as_object() {
1471 let params: Vec<String> = obj
1472 .iter()
1473 .filter(|(_, v)| {
1474 v.get("x-ati-param-location").is_none()
1475 || v.get("description").is_some()
1476 })
1477 .map(|(k, v)| {
1478 let type_str =
1479 v.get("type").and_then(|t| t.as_str()).unwrap_or("string");
1480 let desc = v.get("description").and_then(|d| d.as_str()).unwrap_or("");
1481 format!(" --{k} ({type_str}): {desc}")
1482 })
1483 .collect();
1484 if !params.is_empty() {
1485 summary.push_str("\n Parameters:\n");
1486 summary.push_str(¶ms.join("\n"));
1487 }
1488 }
1489 }
1490 }
1491 summaries.push(summary);
1492 }
1493 summaries.join("\n\n")
1494}
1495
1496fn build_scoped_prompt(
1500 scope_name: &str,
1501 registry: &ManifestRegistry,
1502 skills_section: &str,
1503) -> Option<String> {
1504 if let Some((provider, tool)) = registry.get_tool(scope_name) {
1506 let mut details = format!(
1507 "**Name**: `{}`\n**Provider**: {} (handler: {})\n**Description**: {}\n",
1508 tool.name, provider.name, provider.handler, tool.description
1509 );
1510 if let Some(cat) = &provider.category {
1511 details.push_str(&format!("**Category**: {}\n", cat));
1512 }
1513 if provider.is_cli() {
1514 let cmd = provider.cli_command.as_deref().unwrap_or("?");
1515 details.push_str(&format!(
1516 "\n**Usage**: `ati run {} -- <args>` (passthrough to `{}`)\n",
1517 tool.name, cmd
1518 ));
1519 } else if let Some(schema) = &tool.input_schema {
1520 if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
1521 let required: Vec<String> = schema
1522 .get("required")
1523 .and_then(|r| r.as_array())
1524 .map(|arr| {
1525 arr.iter()
1526 .filter_map(|v| v.as_str().map(|s| s.to_string()))
1527 .collect()
1528 })
1529 .unwrap_or_default();
1530 details.push_str("\n**Parameters**:\n");
1531 for (key, val) in props {
1532 let type_str = val.get("type").and_then(|t| t.as_str()).unwrap_or("string");
1533 let desc = val
1534 .get("description")
1535 .and_then(|d| d.as_str())
1536 .unwrap_or("");
1537 let req = if required.contains(key) {
1538 " **(required)**"
1539 } else {
1540 ""
1541 };
1542 details.push_str(&format!("- `--{key}` ({type_str}{req}): {desc}\n"));
1543 }
1544 }
1545 }
1546
1547 let prompt = format!(
1548 "You are an expert assistant for the `{}` tool, accessed via the `ati` CLI.\n\n\
1549 ## Tool Details\n{}\n\n{}\n\n\
1550 Answer the agent's question about this specific tool. Provide exact commands, explain flags and options, and give practical examples. Be concise and actionable.",
1551 tool.name, details, skills_section
1552 );
1553 return Some(prompt);
1554 }
1555
1556 if registry.has_provider(scope_name) {
1558 let tools = registry.tools_by_provider(scope_name);
1559 if tools.is_empty() {
1560 return None;
1561 }
1562 let tools_context = build_tool_context(&tools);
1563 let prompt = format!(
1564 "You are an expert assistant for the `{}` provider's tools, accessed via the `ati` CLI.\n\n\
1565 ## Tools in provider `{}`\n{}\n\n{}\n\n\
1566 Answer the agent's question about these tools. Provide exact `ati run` commands, explain parameters, and give practical examples. Be concise and actionable.",
1567 scope_name, scope_name, tools_context, skills_section
1568 );
1569 return Some(prompt);
1570 }
1571
1572 None
1573}