1use axum::{
8 body::Body,
9 extract::State,
10 http::{Request as HttpRequest, StatusCode},
11 middleware::{self, Next},
12 response::{IntoResponse, Response},
13 routing::{get, post},
14 Json, Router,
15};
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use std::collections::HashMap;
19use std::net::SocketAddr;
20use std::path::PathBuf;
21use std::sync::Arc;
22
23use crate::core::auth_generator::{AuthCache, GenContext};
24use crate::core::http;
25use crate::core::jwt::{self, JwtConfig, TokenClaims};
26use crate::core::keyring::Keyring;
27use crate::core::manifest::ManifestRegistry;
28use crate::core::mcp_client;
29use crate::core::response;
30use crate::core::scope::ScopeConfig;
31use crate::core::skill::{self, SkillRegistry};
32use crate::core::xai;
33
34pub struct ProxyState {
36 pub registry: ManifestRegistry,
37 pub skill_registry: SkillRegistry,
38 pub keyring: Keyring,
39 pub jwt_config: Option<JwtConfig>,
41 pub jwks_json: Option<Value>,
43 pub auth_cache: AuthCache,
45}
46
47#[derive(Debug, Deserialize)]
50pub struct CallRequest {
51 pub tool_name: String,
52 #[serde(default = "default_args")]
56 pub args: Value,
57 #[serde(default)]
60 pub raw_args: Option<Vec<String>>,
61}
62
63fn default_args() -> Value {
64 Value::Object(serde_json::Map::new())
65}
66
67impl CallRequest {
68 fn args_as_map(&self) -> HashMap<String, Value> {
72 match &self.args {
73 Value::Object(map) => map.iter().map(|(k, v)| (k.clone(), v.clone())).collect(),
74 _ => HashMap::new(),
75 }
76 }
77
78 fn args_as_positional(&self) -> Vec<String> {
81 if let Some(ref raw) = self.raw_args {
83 return raw.clone();
84 }
85 match &self.args {
86 Value::Array(arr) => arr
88 .iter()
89 .map(|v| match v {
90 Value::String(s) => s.clone(),
91 other => other.to_string(),
92 })
93 .collect(),
94 Value::String(s) => s.split_whitespace().map(String::from).collect(),
96 Value::Object(map) => {
98 if let Some(Value::Array(pos)) = map.get("_positional") {
99 return pos
100 .iter()
101 .map(|v| match v {
102 Value::String(s) => s.clone(),
103 other => other.to_string(),
104 })
105 .collect();
106 }
107 let mut result = Vec::new();
109 for (k, v) in map {
110 result.push(format!("--{k}"));
111 match v {
112 Value::String(s) => result.push(s.clone()),
113 Value::Bool(true) => {} other => result.push(other.to_string()),
115 }
116 }
117 result
118 }
119 _ => Vec::new(),
120 }
121 }
122}
123
124#[derive(Debug, Serialize)]
125pub struct CallResponse {
126 pub result: Value,
127 #[serde(skip_serializing_if = "Option::is_none")]
128 pub error: Option<String>,
129}
130
131#[derive(Debug, Deserialize)]
132pub struct HelpRequest {
133 pub query: String,
134 #[serde(default)]
135 pub tool: Option<String>,
136}
137
138#[derive(Debug, Serialize)]
139pub struct HelpResponse {
140 pub content: String,
141 #[serde(skip_serializing_if = "Option::is_none")]
142 pub error: Option<String>,
143}
144
145#[derive(Debug, Serialize)]
146pub struct HealthResponse {
147 pub status: String,
148 pub version: String,
149 pub tools: usize,
150 pub providers: usize,
151 pub skills: usize,
152 pub auth: String,
153}
154
155#[derive(Debug, Deserialize)]
158pub struct SkillsQuery {
159 #[serde(default)]
160 pub category: Option<String>,
161 #[serde(default)]
162 pub provider: Option<String>,
163 #[serde(default)]
164 pub tool: Option<String>,
165 #[serde(default)]
166 pub search: Option<String>,
167}
168
169#[derive(Debug, Deserialize)]
170pub struct SkillDetailQuery {
171 #[serde(default)]
172 pub meta: Option<bool>,
173 #[serde(default)]
174 pub refs: Option<bool>,
175}
176
177#[derive(Debug, Deserialize)]
178pub struct SkillResolveRequest {
179 pub scopes: Vec<String>,
180 #[serde(default)]
182 pub include_content: bool,
183}
184
185#[derive(Debug, Deserialize)]
186pub struct SkillBundleBatchRequest {
187 pub names: Vec<String>,
188}
189
190async fn handle_call(
193 State(state): State<Arc<ProxyState>>,
194 req: HttpRequest<Body>,
195) -> impl IntoResponse {
196 let claims = req.extensions().get::<TokenClaims>().cloned();
198
199 let body_bytes = match axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024).await {
201 Ok(b) => b,
202 Err(e) => {
203 return (
204 StatusCode::BAD_REQUEST,
205 Json(CallResponse {
206 result: Value::Null,
207 error: Some(format!("Failed to read request body: {e}")),
208 }),
209 );
210 }
211 };
212
213 let call_req: CallRequest = match serde_json::from_slice(&body_bytes) {
214 Ok(r) => r,
215 Err(e) => {
216 return (
217 StatusCode::UNPROCESSABLE_ENTITY,
218 Json(CallResponse {
219 result: Value::Null,
220 error: Some(format!("Invalid request: {e}")),
221 }),
222 );
223 }
224 };
225
226 tracing::debug!(
227 tool = %call_req.tool_name,
228 args = ?call_req.args,
229 "POST /call"
230 );
231
232 let (provider, tool) = match state.registry.get_tool(&call_req.tool_name) {
234 Some(pt) => pt,
235 None => {
236 return (
237 StatusCode::NOT_FOUND,
238 Json(CallResponse {
239 result: Value::Null,
240 error: Some(format!("Unknown tool: '{}'", call_req.tool_name)),
241 }),
242 );
243 }
244 };
245
246 if let Some(tool_scope) = &tool.scope {
248 let scopes = match &claims {
249 Some(c) => ScopeConfig::from_jwt(c),
250 None if state.jwt_config.is_none() => ScopeConfig::unrestricted(), None => {
252 return (
253 StatusCode::FORBIDDEN,
254 Json(CallResponse {
255 result: Value::Null,
256 error: Some("Authentication required — no JWT provided".into()),
257 }),
258 );
259 }
260 };
261
262 if let Err(e) = scopes.check_access(&call_req.tool_name, tool_scope) {
263 return (
264 StatusCode::FORBIDDEN,
265 Json(CallResponse {
266 result: Value::Null,
267 error: Some(format!("Access denied: {e}")),
268 }),
269 );
270 }
271 }
272
273 {
275 let scopes = match &claims {
276 Some(c) => ScopeConfig::from_jwt(c),
277 None => ScopeConfig::unrestricted(),
278 };
279 if let Some(ref rate_config) = scopes.rate_config {
280 if let Err(e) = crate::core::rate::check_and_record(&call_req.tool_name, rate_config) {
281 return (
282 StatusCode::TOO_MANY_REQUESTS,
283 Json(CallResponse {
284 result: Value::Null,
285 error: Some(format!("{e}")),
286 }),
287 );
288 }
289 }
290 }
291
292 let gen_ctx = GenContext {
294 jwt_sub: claims
295 .as_ref()
296 .map(|c| c.sub.clone())
297 .unwrap_or_else(|| "dev".into()),
298 jwt_scope: claims
299 .as_ref()
300 .map(|c| c.scope.clone())
301 .unwrap_or_else(|| "*".into()),
302 tool_name: call_req.tool_name.clone(),
303 timestamp: crate::core::jwt::now_secs(),
304 };
305
306 let agent_sub = claims.as_ref().map(|c| c.sub.clone()).unwrap_or_default();
308 let start = std::time::Instant::now();
309
310 let response = match provider.handler.as_str() {
311 "mcp" => {
312 let args_map = call_req.args_as_map();
313 match mcp_client::execute_with_gen(
314 provider,
315 &call_req.tool_name,
316 &args_map,
317 &state.keyring,
318 Some(&gen_ctx),
319 Some(&state.auth_cache),
320 )
321 .await
322 {
323 Ok(result) => (
324 StatusCode::OK,
325 Json(CallResponse {
326 result,
327 error: None,
328 }),
329 ),
330 Err(e) => (
331 StatusCode::BAD_GATEWAY,
332 Json(CallResponse {
333 result: Value::Null,
334 error: Some(format!("MCP error: {e}")),
335 }),
336 ),
337 }
338 }
339 "cli" => {
340 let positional = call_req.args_as_positional();
341 match crate::core::cli_executor::execute_with_gen(
342 provider,
343 &positional,
344 &state.keyring,
345 Some(&gen_ctx),
346 Some(&state.auth_cache),
347 )
348 .await
349 {
350 Ok(result) => (
351 StatusCode::OK,
352 Json(CallResponse {
353 result,
354 error: None,
355 }),
356 ),
357 Err(e) => (
358 StatusCode::BAD_GATEWAY,
359 Json(CallResponse {
360 result: Value::Null,
361 error: Some(format!("CLI error: {e}")),
362 }),
363 ),
364 }
365 }
366 _ => {
367 let args_map = call_req.args_as_map();
368 let raw_response = match match provider.handler.as_str() {
369 "xai" => xai::execute_xai_tool(provider, tool, &args_map, &state.keyring).await,
370 _ => {
371 http::execute_tool_with_gen(
372 provider,
373 tool,
374 &args_map,
375 &state.keyring,
376 Some(&gen_ctx),
377 Some(&state.auth_cache),
378 )
379 .await
380 }
381 } {
382 Ok(resp) => resp,
383 Err(e) => {
384 let duration = start.elapsed();
385 write_proxy_audit(&call_req, &agent_sub, duration, Some(&e.to_string()));
386 return (
387 StatusCode::BAD_GATEWAY,
388 Json(CallResponse {
389 result: Value::Null,
390 error: Some(format!("Upstream API error: {e}")),
391 }),
392 );
393 }
394 };
395
396 let processed = match response::process_response(&raw_response, tool.response.as_ref())
397 {
398 Ok(p) => p,
399 Err(e) => {
400 let duration = start.elapsed();
401 write_proxy_audit(&call_req, &agent_sub, duration, Some(&e.to_string()));
402 return (
403 StatusCode::INTERNAL_SERVER_ERROR,
404 Json(CallResponse {
405 result: raw_response,
406 error: Some(format!("Response processing error: {e}")),
407 }),
408 );
409 }
410 };
411
412 (
413 StatusCode::OK,
414 Json(CallResponse {
415 result: processed,
416 error: None,
417 }),
418 )
419 }
420 };
421
422 let duration = start.elapsed();
423 let error_msg = response.1.error.as_deref();
424 write_proxy_audit(&call_req, &agent_sub, duration, error_msg);
425
426 response
427}
428
429async fn handle_help(
430 State(state): State<Arc<ProxyState>>,
431 Json(req): Json<HelpRequest>,
432) -> impl IntoResponse {
433 tracing::debug!(query = %req.query, tool = ?req.tool, "POST /help");
434
435 let (llm_provider, llm_tool) = match state.registry.get_tool("_chat_completion") {
436 Some(pt) => pt,
437 None => {
438 return (
439 StatusCode::SERVICE_UNAVAILABLE,
440 Json(HelpResponse {
441 content: String::new(),
442 error: Some("No _llm.toml manifest found. Proxy help requires a configured LLM provider.".into()),
443 }),
444 );
445 }
446 };
447
448 let api_key = match llm_provider
449 .auth_key_name
450 .as_deref()
451 .and_then(|k| state.keyring.get(k))
452 {
453 Some(key) => key.to_string(),
454 None => {
455 return (
456 StatusCode::SERVICE_UNAVAILABLE,
457 Json(HelpResponse {
458 content: String::new(),
459 error: Some("LLM API key not found in keyring".into()),
460 }),
461 );
462 }
463 };
464
465 let scopes = ScopeConfig::unrestricted();
466 let resolved_skills = skill::resolve_skills(&state.skill_registry, &state.registry, &scopes);
467 let skills_section = if resolved_skills.is_empty() {
468 String::new()
469 } else {
470 format!(
471 "## Available Skills (methodology guides)\n{}",
472 skill::build_skill_context(&resolved_skills)
473 )
474 };
475
476 let system_prompt = if let Some(ref tool_name) = req.tool {
478 match build_scoped_prompt(tool_name, &state.registry, &skills_section) {
480 Some(prompt) => prompt,
481 None => {
482 tracing::debug!(scope = %tool_name, "scope not found, falling back to unscoped");
484 let all_tools = state.registry.list_public_tools();
485 let tools_context = build_tool_context(&all_tools);
486 HELP_SYSTEM_PROMPT
487 .replace("{tools}", &tools_context)
488 .replace("{skills_section}", &skills_section)
489 }
490 }
491 } else {
492 let all_tools = state.registry.list_public_tools();
493 let tools_context = build_tool_context(&all_tools);
494 HELP_SYSTEM_PROMPT
495 .replace("{tools}", &tools_context)
496 .replace("{skills_section}", &skills_section)
497 };
498
499 let request_body = serde_json::json!({
500 "model": "zai-glm-4.7",
501 "messages": [
502 {"role": "system", "content": system_prompt},
503 {"role": "user", "content": req.query}
504 ],
505 "max_completion_tokens": 1536,
506 "temperature": 0.3
507 });
508
509 let client = reqwest::Client::new();
510 let url = format!(
511 "{}{}",
512 llm_provider.base_url.trim_end_matches('/'),
513 llm_tool.endpoint
514 );
515
516 let response = match client
517 .post(&url)
518 .bearer_auth(&api_key)
519 .json(&request_body)
520 .send()
521 .await
522 {
523 Ok(r) => r,
524 Err(e) => {
525 return (
526 StatusCode::BAD_GATEWAY,
527 Json(HelpResponse {
528 content: String::new(),
529 error: Some(format!("LLM request failed: {e}")),
530 }),
531 );
532 }
533 };
534
535 if !response.status().is_success() {
536 let status = response.status();
537 let body = response.text().await.unwrap_or_default();
538 return (
539 StatusCode::BAD_GATEWAY,
540 Json(HelpResponse {
541 content: String::new(),
542 error: Some(format!("LLM API error ({status}): {body}")),
543 }),
544 );
545 }
546
547 let body: Value = match response.json().await {
548 Ok(b) => b,
549 Err(e) => {
550 return (
551 StatusCode::INTERNAL_SERVER_ERROR,
552 Json(HelpResponse {
553 content: String::new(),
554 error: Some(format!("Failed to parse LLM response: {e}")),
555 }),
556 );
557 }
558 };
559
560 let content = body
561 .pointer("/choices/0/message/content")
562 .and_then(|c| c.as_str())
563 .unwrap_or("No response from LLM")
564 .to_string();
565
566 (
567 StatusCode::OK,
568 Json(HelpResponse {
569 content,
570 error: None,
571 }),
572 )
573}
574
575async fn handle_health(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
576 let auth = if state.jwt_config.is_some() {
577 "jwt"
578 } else {
579 "disabled"
580 };
581
582 Json(HealthResponse {
583 status: "ok".into(),
584 version: env!("CARGO_PKG_VERSION").into(),
585 tools: state.registry.list_public_tools().len(),
586 providers: state.registry.list_providers().len(),
587 skills: state.skill_registry.skill_count(),
588 auth: auth.into(),
589 })
590}
591
592async fn handle_jwks(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
594 match &state.jwks_json {
595 Some(jwks) => (StatusCode::OK, Json(jwks.clone())),
596 None => (
597 StatusCode::NOT_FOUND,
598 Json(serde_json::json!({"error": "JWKS not configured"})),
599 ),
600 }
601}
602
603async fn handle_mcp(
608 State(state): State<Arc<ProxyState>>,
609 Json(msg): Json<Value>,
610) -> impl IntoResponse {
611 let method = msg.get("method").and_then(|m| m.as_str()).unwrap_or("");
612 let id = msg.get("id").cloned();
613
614 tracing::debug!(%method, "POST /mcp");
615
616 match method {
617 "initialize" => {
618 let result = serde_json::json!({
619 "protocolVersion": "2025-03-26",
620 "capabilities": {
621 "tools": { "listChanged": false }
622 },
623 "serverInfo": {
624 "name": "ati-proxy",
625 "version": env!("CARGO_PKG_VERSION")
626 }
627 });
628 jsonrpc_success(id, result)
629 }
630
631 "notifications/initialized" => (StatusCode::ACCEPTED, Json(Value::Null)),
632
633 "tools/list" => {
634 let all_tools = state.registry.list_public_tools();
635 let mcp_tools: Vec<Value> = all_tools
636 .iter()
637 .map(|(_provider, tool)| {
638 serde_json::json!({
639 "name": tool.name,
640 "description": tool.description,
641 "inputSchema": tool.input_schema.clone().unwrap_or(serde_json::json!({
642 "type": "object",
643 "properties": {}
644 }))
645 })
646 })
647 .collect();
648
649 let result = serde_json::json!({
650 "tools": mcp_tools,
651 });
652 jsonrpc_success(id, result)
653 }
654
655 "tools/call" => {
656 let params = msg.get("params").cloned().unwrap_or(Value::Null);
657 let tool_name = params.get("name").and_then(|n| n.as_str()).unwrap_or("");
658 let arguments: HashMap<String, Value> = params
659 .get("arguments")
660 .and_then(|a| serde_json::from_value(a.clone()).ok())
661 .unwrap_or_default();
662
663 if tool_name.is_empty() {
664 return jsonrpc_error(id, -32602, "Missing tool name in params.name");
665 }
666
667 let (provider, _tool) = match state.registry.get_tool(tool_name) {
668 Some(pt) => pt,
669 None => {
670 return jsonrpc_error(id, -32602, &format!("Unknown tool: '{tool_name}'"));
671 }
672 };
673
674 tracing::debug!(%tool_name, provider = %provider.name, "MCP tools/call");
675
676 let mcp_gen_ctx = GenContext {
677 jwt_sub: "dev".into(),
678 jwt_scope: "*".into(),
679 tool_name: tool_name.to_string(),
680 timestamp: crate::core::jwt::now_secs(),
681 };
682
683 let result = if provider.is_mcp() {
684 mcp_client::execute_with_gen(
685 provider,
686 tool_name,
687 &arguments,
688 &state.keyring,
689 Some(&mcp_gen_ctx),
690 Some(&state.auth_cache),
691 )
692 .await
693 } else if provider.is_cli() {
694 let raw: Vec<String> = arguments
696 .iter()
697 .flat_map(|(k, v)| {
698 let val = match v {
699 Value::String(s) => s.clone(),
700 other => other.to_string(),
701 };
702 vec![format!("--{k}"), val]
703 })
704 .collect();
705 crate::core::cli_executor::execute_with_gen(
706 provider,
707 &raw,
708 &state.keyring,
709 Some(&mcp_gen_ctx),
710 Some(&state.auth_cache),
711 )
712 .await
713 .map_err(|e| mcp_client::McpError::Transport(e.to_string()))
714 } else {
715 match match provider.handler.as_str() {
716 "xai" => {
717 xai::execute_xai_tool(provider, _tool, &arguments, &state.keyring).await
718 }
719 _ => {
720 http::execute_tool_with_gen(
721 provider,
722 _tool,
723 &arguments,
724 &state.keyring,
725 Some(&mcp_gen_ctx),
726 Some(&state.auth_cache),
727 )
728 .await
729 }
730 } {
731 Ok(val) => Ok(val),
732 Err(e) => Err(mcp_client::McpError::Transport(e.to_string())),
733 }
734 };
735
736 match result {
737 Ok(value) => {
738 let text = match &value {
739 Value::String(s) => s.clone(),
740 other => serde_json::to_string_pretty(other).unwrap_or_default(),
741 };
742 let mcp_result = serde_json::json!({
743 "content": [{"type": "text", "text": text}],
744 "isError": false,
745 });
746 jsonrpc_success(id, mcp_result)
747 }
748 Err(e) => {
749 let mcp_result = serde_json::json!({
750 "content": [{"type": "text", "text": format!("Error: {e}")}],
751 "isError": true,
752 });
753 jsonrpc_success(id, mcp_result)
754 }
755 }
756 }
757
758 _ => jsonrpc_error(id, -32601, &format!("Method not found: '{method}'")),
759 }
760}
761
762fn jsonrpc_success(id: Option<Value>, result: Value) -> (StatusCode, Json<Value>) {
763 (
764 StatusCode::OK,
765 Json(serde_json::json!({
766 "jsonrpc": "2.0",
767 "id": id,
768 "result": result,
769 })),
770 )
771}
772
773fn jsonrpc_error(id: Option<Value>, code: i64, message: &str) -> (StatusCode, Json<Value>) {
774 (
775 StatusCode::OK,
776 Json(serde_json::json!({
777 "jsonrpc": "2.0",
778 "id": id,
779 "error": {
780 "code": code,
781 "message": message,
782 }
783 })),
784 )
785}
786
787async fn handle_skills_list(
792 State(state): State<Arc<ProxyState>>,
793 axum::extract::Query(query): axum::extract::Query<SkillsQuery>,
794) -> impl IntoResponse {
795 tracing::debug!(
796 category = ?query.category,
797 provider = ?query.provider,
798 tool = ?query.tool,
799 search = ?query.search,
800 "GET /skills"
801 );
802
803 let skills: Vec<&skill::SkillMeta> = if let Some(search_query) = &query.search {
804 state.skill_registry.search(search_query)
805 } else if let Some(cat) = &query.category {
806 state.skill_registry.skills_for_category(cat)
807 } else if let Some(prov) = &query.provider {
808 state.skill_registry.skills_for_provider(prov)
809 } else if let Some(t) = &query.tool {
810 state.skill_registry.skills_for_tool(t)
811 } else {
812 state.skill_registry.list_skills().iter().collect()
813 };
814
815 let json: Vec<Value> = skills
816 .iter()
817 .map(|s| {
818 serde_json::json!({
819 "name": s.name,
820 "version": s.version,
821 "description": s.description,
822 "tools": s.tools,
823 "providers": s.providers,
824 "categories": s.categories,
825 "hint": s.hint,
826 })
827 })
828 .collect();
829
830 (StatusCode::OK, Json(Value::Array(json)))
831}
832
833async fn handle_skill_detail(
834 State(state): State<Arc<ProxyState>>,
835 axum::extract::Path(name): axum::extract::Path<String>,
836 axum::extract::Query(query): axum::extract::Query<SkillDetailQuery>,
837) -> impl IntoResponse {
838 tracing::debug!(%name, meta = ?query.meta, refs = ?query.refs, "GET /skills/:name");
839
840 let skill_meta = match state.skill_registry.get_skill(&name) {
841 Some(s) => s,
842 None => {
843 return (
844 StatusCode::NOT_FOUND,
845 Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
846 );
847 }
848 };
849
850 if query.meta.unwrap_or(false) {
851 return (
852 StatusCode::OK,
853 Json(serde_json::json!({
854 "name": skill_meta.name,
855 "version": skill_meta.version,
856 "description": skill_meta.description,
857 "author": skill_meta.author,
858 "tools": skill_meta.tools,
859 "providers": skill_meta.providers,
860 "categories": skill_meta.categories,
861 "keywords": skill_meta.keywords,
862 "hint": skill_meta.hint,
863 "depends_on": skill_meta.depends_on,
864 "suggests": skill_meta.suggests,
865 "license": skill_meta.license,
866 "compatibility": skill_meta.compatibility,
867 "allowed_tools": skill_meta.allowed_tools,
868 "format": skill_meta.format,
869 })),
870 );
871 }
872
873 let content = match state.skill_registry.read_content(&name) {
874 Ok(c) => c,
875 Err(e) => {
876 return (
877 StatusCode::INTERNAL_SERVER_ERROR,
878 Json(serde_json::json!({"error": format!("Failed to read skill: {e}")})),
879 );
880 }
881 };
882
883 let mut response = serde_json::json!({
884 "name": skill_meta.name,
885 "version": skill_meta.version,
886 "description": skill_meta.description,
887 "content": content,
888 });
889
890 if query.refs.unwrap_or(false) {
891 if let Ok(refs) = state.skill_registry.list_references(&name) {
892 response["references"] = serde_json::json!(refs);
893 }
894 }
895
896 (StatusCode::OK, Json(response))
897}
898
899async fn handle_skill_bundle(
903 State(state): State<Arc<ProxyState>>,
904 axum::extract::Path(name): axum::extract::Path<String>,
905) -> impl IntoResponse {
906 tracing::debug!(skill = %name, "GET /skills/:name/bundle");
907
908 let files = match state.skill_registry.bundle_files(&name) {
909 Ok(f) => f,
910 Err(_) => {
911 return (
912 StatusCode::NOT_FOUND,
913 Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
914 );
915 }
916 };
917
918 let mut file_map = serde_json::Map::new();
920 for (path, data) in &files {
921 match std::str::from_utf8(data) {
922 Ok(text) => {
923 file_map.insert(path.clone(), Value::String(text.to_string()));
924 }
925 Err(_) => {
926 use base64::Engine;
928 let encoded = base64::engine::general_purpose::STANDARD.encode(data);
929 file_map.insert(path.clone(), serde_json::json!({"base64": encoded}));
930 }
931 }
932 }
933
934 (
935 StatusCode::OK,
936 Json(serde_json::json!({
937 "name": name,
938 "files": file_map,
939 })),
940 )
941}
942
943async fn handle_skills_bundle_batch(
947 State(state): State<Arc<ProxyState>>,
948 Json(req): Json<SkillBundleBatchRequest>,
949) -> impl IntoResponse {
950 const MAX_BATCH: usize = 50;
951 if req.names.len() > MAX_BATCH {
952 return (
953 StatusCode::BAD_REQUEST,
954 Json(
955 serde_json::json!({"error": format!("batch size {} exceeds limit of {MAX_BATCH}", req.names.len())}),
956 ),
957 );
958 }
959
960 tracing::debug!(names = ?req.names, "POST /skills/bundle");
961
962 let mut result = serde_json::Map::new();
963 let mut missing: Vec<String> = Vec::new();
964
965 for name in &req.names {
966 let files = match state.skill_registry.bundle_files(name) {
967 Ok(f) => f,
968 Err(_) => {
969 missing.push(name.clone());
970 continue;
971 }
972 };
973
974 let mut file_map = serde_json::Map::new();
975 for (path, data) in &files {
976 match std::str::from_utf8(data) {
977 Ok(text) => {
978 file_map.insert(path.clone(), Value::String(text.to_string()));
979 }
980 Err(_) => {
981 use base64::Engine;
982 let encoded = base64::engine::general_purpose::STANDARD.encode(data);
983 file_map.insert(path.clone(), serde_json::json!({"base64": encoded}));
984 }
985 }
986 }
987
988 result.insert(name.clone(), serde_json::json!({ "files": file_map }));
989 }
990
991 (
992 StatusCode::OK,
993 Json(serde_json::json!({ "skills": result, "missing": missing })),
994 )
995}
996
997async fn handle_skills_resolve(
998 State(state): State<Arc<ProxyState>>,
999 Json(req): Json<SkillResolveRequest>,
1000) -> impl IntoResponse {
1001 tracing::debug!(scopes = ?req.scopes, include_content = req.include_content, "POST /skills/resolve");
1002
1003 let include_content = req.include_content;
1004 let scopes = ScopeConfig {
1005 scopes: req.scopes,
1006 sub: String::new(),
1007 expires_at: 0,
1008 rate_config: None,
1009 };
1010
1011 let resolved = skill::resolve_skills(&state.skill_registry, &state.registry, &scopes);
1012
1013 let json: Vec<Value> = resolved
1014 .iter()
1015 .map(|s| {
1016 let mut entry = serde_json::json!({
1017 "name": s.name,
1018 "version": s.version,
1019 "description": s.description,
1020 "tools": s.tools,
1021 "providers": s.providers,
1022 "categories": s.categories,
1023 });
1024 if include_content {
1025 if let Ok(content) = state.skill_registry.read_content(&s.name) {
1026 entry["content"] = Value::String(content);
1027 }
1028 }
1029 entry
1030 })
1031 .collect();
1032
1033 (StatusCode::OK, Json(Value::Array(json)))
1034}
1035
1036async fn auth_middleware(
1044 State(state): State<Arc<ProxyState>>,
1045 mut req: HttpRequest<Body>,
1046 next: Next,
1047) -> Result<Response, StatusCode> {
1048 let path = req.uri().path();
1049
1050 if path == "/health" || path == "/.well-known/jwks.json" {
1052 return Ok(next.run(req).await);
1053 }
1054
1055 let jwt_config = match &state.jwt_config {
1057 Some(c) => c,
1058 None => return Ok(next.run(req).await),
1059 };
1060
1061 let auth_header = req
1063 .headers()
1064 .get("authorization")
1065 .and_then(|v| v.to_str().ok());
1066
1067 let token = match auth_header {
1068 Some(header) if header.starts_with("Bearer ") => &header[7..],
1069 _ => return Err(StatusCode::UNAUTHORIZED),
1070 };
1071
1072 match jwt::validate(token, jwt_config) {
1074 Ok(claims) => {
1075 tracing::debug!(sub = %claims.sub, scopes = %claims.scope, "JWT validated");
1076 req.extensions_mut().insert(claims);
1077 Ok(next.run(req).await)
1078 }
1079 Err(e) => {
1080 tracing::debug!(error = %e, "JWT validation failed");
1081 Err(StatusCode::UNAUTHORIZED)
1082 }
1083 }
1084}
1085
1086pub fn build_router(state: Arc<ProxyState>) -> Router {
1090 Router::new()
1091 .route("/call", post(handle_call))
1092 .route("/help", post(handle_help))
1093 .route("/mcp", post(handle_mcp))
1094 .route("/skills", get(handle_skills_list))
1095 .route("/skills/resolve", post(handle_skills_resolve))
1096 .route("/skills/bundle", post(handle_skills_bundle_batch))
1097 .route("/skills/{name}", get(handle_skill_detail))
1098 .route("/skills/{name}/bundle", get(handle_skill_bundle))
1099 .route("/health", get(handle_health))
1100 .route("/.well-known/jwks.json", get(handle_jwks))
1101 .layer(middleware::from_fn_with_state(
1102 state.clone(),
1103 auth_middleware,
1104 ))
1105 .with_state(state)
1106}
1107
1108pub async fn run(
1112 port: u16,
1113 bind_addr: Option<String>,
1114 ati_dir: PathBuf,
1115 _verbose: bool,
1116 env_keys: bool,
1117) -> Result<(), Box<dyn std::error::Error>> {
1118 let manifests_dir = ati_dir.join("manifests");
1120 let registry = ManifestRegistry::load(&manifests_dir)?;
1121
1122 let tool_count = registry.list_public_tools().len();
1123 let provider_count = registry.list_providers().len();
1124
1125 let keyring_source;
1127 let keyring = if env_keys {
1128 let kr = Keyring::from_env();
1130 let key_names = kr.key_names();
1131 tracing::info!(
1132 count = key_names.len(),
1133 "loaded API keys from ATI_KEY_* env vars"
1134 );
1135 for name in &key_names {
1136 tracing::debug!(key = %name, "env key loaded");
1137 }
1138 keyring_source = "env-vars (ATI_KEY_*)";
1139 kr
1140 } else {
1141 let keyring_path = ati_dir.join("keyring.enc");
1143 if keyring_path.exists() {
1144 if let Ok(kr) = Keyring::load(&keyring_path) {
1145 keyring_source = "keyring.enc (sealed key)";
1146 kr
1147 } else if let Ok(kr) = Keyring::load_local(&keyring_path, &ati_dir) {
1148 keyring_source = "keyring.enc (persistent key)";
1149 kr
1150 } else {
1151 tracing::warn!("keyring.enc exists but could not be decrypted");
1152 keyring_source = "empty (decryption failed)";
1153 Keyring::empty()
1154 }
1155 } else {
1156 let creds_path = ati_dir.join("credentials");
1157 if creds_path.exists() {
1158 match Keyring::load_credentials(&creds_path) {
1159 Ok(kr) => {
1160 keyring_source = "credentials (plaintext)";
1161 kr
1162 }
1163 Err(e) => {
1164 tracing::warn!(error = %e, "failed to load credentials");
1165 keyring_source = "empty (credentials error)";
1166 Keyring::empty()
1167 }
1168 }
1169 } else {
1170 tracing::warn!("no keyring.enc or credentials found — running without API keys");
1171 tracing::warn!("tools requiring authentication will fail");
1172 keyring_source = "empty (no auth)";
1173 Keyring::empty()
1174 }
1175 }
1176 };
1177
1178 let mcp_providers: Vec<(String, String)> = registry
1180 .list_mcp_providers()
1181 .iter()
1182 .map(|p| (p.name.clone(), p.mcp_transport_type().to_string()))
1183 .collect();
1184 let mcp_count = mcp_providers.len();
1185 let openapi_providers: Vec<String> = registry
1186 .list_openapi_providers()
1187 .iter()
1188 .map(|p| p.name.clone())
1189 .collect();
1190 let openapi_count = openapi_providers.len();
1191
1192 let skills_dir = ati_dir.join("skills");
1194 let mut skill_registry = SkillRegistry::load(&skills_dir).unwrap_or_else(|e| {
1195 tracing::warn!(error = %e, "failed to load skills");
1196 SkillRegistry::load(std::path::Path::new("/nonexistent-fallback")).unwrap()
1197 });
1198
1199 if let Ok(registry_url) = std::env::var("ATI_SKILL_REGISTRY") {
1201 if let Some(bucket) = registry_url.strip_prefix("gcs://") {
1202 let cred_key = "gcp_credentials";
1203 if let Some(cred_json) = keyring.get(cred_key) {
1204 match crate::core::gcs::GcsClient::new(bucket.to_string(), cred_json) {
1205 Ok(client) => match crate::core::gcs::GcsSkillSource::load(&client).await {
1206 Ok(gcs_source) => {
1207 let gcs_count = gcs_source.skill_count();
1208 skill_registry.merge(gcs_source);
1209 tracing::info!(
1210 bucket = %bucket,
1211 skills = gcs_count,
1212 "loaded skills from GCS registry"
1213 );
1214 }
1215 Err(e) => {
1216 tracing::warn!(error = %e, bucket = %bucket, "failed to load GCS skills");
1217 }
1218 },
1219 Err(e) => {
1220 tracing::warn!(error = %e, "failed to init GCS client");
1221 }
1222 }
1223 } else {
1224 tracing::warn!(
1225 key = %cred_key,
1226 "ATI_SKILL_REGISTRY set but GCS credentials not found in keyring"
1227 );
1228 }
1229 } else {
1230 tracing::warn!(
1231 url = %registry_url,
1232 "unsupported skill registry scheme (only gcs:// is supported)"
1233 );
1234 }
1235 }
1236
1237 let skill_count = skill_registry.skill_count();
1238
1239 let jwt_config = match jwt::config_from_env() {
1241 Ok(config) => config,
1242 Err(e) => {
1243 tracing::warn!(error = %e, "JWT config error");
1244 None
1245 }
1246 };
1247
1248 let auth_status = if jwt_config.is_some() {
1249 "JWT enabled"
1250 } else {
1251 "DISABLED (no JWT keys configured)"
1252 };
1253
1254 let jwks_json = jwt_config.as_ref().and_then(|config| {
1256 config
1257 .public_key_pem
1258 .as_ref()
1259 .and_then(|pem| jwt::public_key_to_jwks(pem, config.algorithm, "ati-proxy-1").ok())
1260 });
1261
1262 let state = Arc::new(ProxyState {
1263 registry,
1264 skill_registry,
1265 keyring,
1266 jwt_config,
1267 jwks_json,
1268 auth_cache: AuthCache::new(),
1269 });
1270
1271 let app = build_router(state);
1272
1273 let addr: SocketAddr = if let Some(ref bind) = bind_addr {
1274 format!("{bind}:{port}").parse()?
1275 } else {
1276 SocketAddr::from(([127, 0, 0, 1], port))
1277 };
1278
1279 tracing::info!(
1280 version = env!("CARGO_PKG_VERSION"),
1281 %addr,
1282 auth = auth_status,
1283 ati_dir = %ati_dir.display(),
1284 tools = tool_count,
1285 providers = provider_count,
1286 mcp = mcp_count,
1287 openapi = openapi_count,
1288 skills = skill_count,
1289 keyring = keyring_source,
1290 "ATI proxy server starting"
1291 );
1292 for (name, transport) in &mcp_providers {
1293 tracing::info!(provider = %name, transport = %transport, "MCP provider");
1294 }
1295 for name in &openapi_providers {
1296 tracing::info!(provider = %name, "OpenAPI provider");
1297 }
1298
1299 let listener = tokio::net::TcpListener::bind(addr).await?;
1300 axum::serve(listener, app).await?;
1301
1302 Ok(())
1303}
1304
1305fn write_proxy_audit(
1307 call_req: &CallRequest,
1308 agent_sub: &str,
1309 duration: std::time::Duration,
1310 error: Option<&str>,
1311) {
1312 let entry = crate::core::audit::AuditEntry {
1313 ts: chrono::Utc::now().to_rfc3339(),
1314 tool: call_req.tool_name.clone(),
1315 args: crate::core::audit::sanitize_args(&call_req.args),
1316 status: if error.is_some() {
1317 crate::core::audit::AuditStatus::Error
1318 } else {
1319 crate::core::audit::AuditStatus::Ok
1320 },
1321 duration_ms: duration.as_millis() as u64,
1322 agent_sub: agent_sub.to_string(),
1323 error: error.map(|s| s.to_string()),
1324 exit_code: None,
1325 };
1326 let _ = crate::core::audit::append(&entry);
1327}
1328
/// System prompt template for the general (unscoped) help assistant.
///
/// The text contains two literal placeholders, `{tools}` and
/// `{skills_section}`. Because this is a plain `&str` constant rather than a
/// `format!` literal, the placeholders are not interpolated automatically;
/// presumably the caller substitutes them by string replacement before
/// sending the prompt (the call site is outside this part of the file).
const HELP_SYSTEM_PROMPT: &str = r#"You are a helpful assistant for an AI agent that uses external tools via the `ati` CLI.

## Available Tools
{tools}

{skills_section}

Answer the agent's question naturally, like a knowledgeable colleague would. Keep it short but useful:

- Explain which tools to use and why, with `ati run` commands showing realistic parameter values
- If multiple steps are needed, walk through them briefly in order
- Mention important gotchas or parameter choices that matter
- If skills are relevant, suggest `ati skill show <name>` for the full methodology

Keep your answer concise — a few short paragraphs with embedded code blocks. Only recommend tools from the list above."#;
1346
1347fn build_tool_context(
1348 tools: &[(
1349 &crate::core::manifest::Provider,
1350 &crate::core::manifest::Tool,
1351 )],
1352) -> String {
1353 let mut summaries = Vec::new();
1354 for (provider, tool) in tools {
1355 let mut summary = if let Some(cat) = &provider.category {
1356 format!(
1357 "- **{}** (provider: {}, category: {}): {}",
1358 tool.name, provider.name, cat, tool.description
1359 )
1360 } else {
1361 format!(
1362 "- **{}** (provider: {}): {}",
1363 tool.name, provider.name, tool.description
1364 )
1365 };
1366 if !tool.tags.is_empty() {
1367 summary.push_str(&format!("\n Tags: {}", tool.tags.join(", ")));
1368 }
1369 if provider.is_cli() && tool.input_schema.is_none() {
1371 let cmd = provider.cli_command.as_deref().unwrap_or("?");
1372 summary.push_str(&format!(
1373 "\n Usage: `ati run {} -- <args>` (passthrough to `{}`)",
1374 tool.name, cmd
1375 ));
1376 } else if let Some(schema) = &tool.input_schema {
1377 if let Some(props) = schema.get("properties") {
1378 if let Some(obj) = props.as_object() {
1379 let params: Vec<String> = obj
1380 .iter()
1381 .filter(|(_, v)| {
1382 v.get("x-ati-param-location").is_none()
1383 || v.get("description").is_some()
1384 })
1385 .map(|(k, v)| {
1386 let type_str =
1387 v.get("type").and_then(|t| t.as_str()).unwrap_or("string");
1388 let desc = v.get("description").and_then(|d| d.as_str()).unwrap_or("");
1389 format!(" --{k} ({type_str}): {desc}")
1390 })
1391 .collect();
1392 if !params.is_empty() {
1393 summary.push_str("\n Parameters:\n");
1394 summary.push_str(¶ms.join("\n"));
1395 }
1396 }
1397 }
1398 }
1399 summaries.push(summary);
1400 }
1401 summaries.join("\n\n")
1402}
1403
1404fn build_scoped_prompt(
1408 scope_name: &str,
1409 registry: &ManifestRegistry,
1410 skills_section: &str,
1411) -> Option<String> {
1412 if let Some((provider, tool)) = registry.get_tool(scope_name) {
1414 let mut details = format!(
1415 "**Name**: `{}`\n**Provider**: {} (handler: {})\n**Description**: {}\n",
1416 tool.name, provider.name, provider.handler, tool.description
1417 );
1418 if let Some(cat) = &provider.category {
1419 details.push_str(&format!("**Category**: {}\n", cat));
1420 }
1421 if provider.is_cli() {
1422 let cmd = provider.cli_command.as_deref().unwrap_or("?");
1423 details.push_str(&format!(
1424 "\n**Usage**: `ati run {} -- <args>` (passthrough to `{}`)\n",
1425 tool.name, cmd
1426 ));
1427 } else if let Some(schema) = &tool.input_schema {
1428 if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
1429 let required: Vec<String> = schema
1430 .get("required")
1431 .and_then(|r| r.as_array())
1432 .map(|arr| {
1433 arr.iter()
1434 .filter_map(|v| v.as_str().map(|s| s.to_string()))
1435 .collect()
1436 })
1437 .unwrap_or_default();
1438 details.push_str("\n**Parameters**:\n");
1439 for (key, val) in props {
1440 let type_str = val.get("type").and_then(|t| t.as_str()).unwrap_or("string");
1441 let desc = val
1442 .get("description")
1443 .and_then(|d| d.as_str())
1444 .unwrap_or("");
1445 let req = if required.contains(key) {
1446 " **(required)**"
1447 } else {
1448 ""
1449 };
1450 details.push_str(&format!("- `--{key}` ({type_str}{req}): {desc}\n"));
1451 }
1452 }
1453 }
1454
1455 let prompt = format!(
1456 "You are an expert assistant for the `{}` tool, accessed via the `ati` CLI.\n\n\
1457 ## Tool Details\n{}\n\n{}\n\n\
1458 Answer the agent's question about this specific tool. Provide exact commands, explain flags and options, and give practical examples. Be concise and actionable.",
1459 tool.name, details, skills_section
1460 );
1461 return Some(prompt);
1462 }
1463
1464 if registry.has_provider(scope_name) {
1466 let tools = registry.tools_by_provider(scope_name);
1467 if tools.is_empty() {
1468 return None;
1469 }
1470 let tools_context = build_tool_context(&tools);
1471 let prompt = format!(
1472 "You are an expert assistant for the `{}` provider's tools, accessed via the `ati` CLI.\n\n\
1473 ## Tools in provider `{}`\n{}\n\n{}\n\n\
1474 Answer the agent's question about these tools. Provide exact `ati run` commands, explain parameters, and give practical examples. Be concise and actionable.",
1475 scope_name, scope_name, tools_context, skills_section
1476 );
1477 return Some(prompt);
1478 }
1479
1480 None
1481}