1use axum::{
8 body::Body,
9 extract::State,
10 http::{Request as HttpRequest, StatusCode},
11 middleware::{self, Next},
12 response::{IntoResponse, Response},
13 routing::{get, post},
14 Json, Router,
15};
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use std::collections::HashMap;
19use std::net::SocketAddr;
20use std::path::PathBuf;
21use std::sync::Arc;
22
23use crate::core::auth_generator::{AuthCache, GenContext};
24use crate::core::http;
25use crate::core::jwt::{self, JwtConfig, TokenClaims};
26use crate::core::keyring::Keyring;
27use crate::core::manifest::ManifestRegistry;
28use crate::core::mcp_client;
29use crate::core::response;
30use crate::core::scope::ScopeConfig;
31use crate::core::skill::{self, SkillRegistry};
32use crate::core::xai;
33
/// Shared state handed to every proxy handler through `State<Arc<ProxyState>>`.
pub struct ProxyState {
    /// Loaded manifests; resolves tool names to `(provider, tool)` pairs.
    pub registry: ManifestRegistry,
    /// Skills served by the `/skills` endpoints and summarized into `/help` prompts.
    pub skill_registry: SkillRegistry,
    /// API keys handed to the tool executors and used for the `/help` LLM call.
    pub keyring: Keyring,
    /// JWT validation settings; `None` disables authentication in `auth_middleware`.
    pub jwt_config: Option<JwtConfig>,
    /// JWKS document served at `/.well-known/jwks.json`, when a public key is configured.
    pub jwks_json: Option<Value>,
    /// Cache passed to the executors alongside the `GenContext` (presumably for generated
    /// auth credentials — see `auth_generator`).
    pub auth_cache: AuthCache,
}
46
#[derive(Debug, Deserialize)]
/// JSON body of `POST /call`.
pub struct CallRequest {
    /// Name of the tool to execute, as registered in the manifest registry.
    pub tool_name: String,
    /// Tool arguments; an empty JSON object when omitted. Objects are used as named
    /// parameters; arrays, strings and an object's `_positional` array feed CLI tools.
    #[serde(default = "default_args")]
    pub args: Value,
    /// Pre-split CLI arguments; when present they are used verbatim for CLI tools.
    #[serde(default)]
    pub raw_args: Option<Vec<String>>,
}
62
63fn default_args() -> Value {
64 Value::Object(serde_json::Map::new())
65}
66
67impl CallRequest {
68 fn args_as_map(&self) -> HashMap<String, Value> {
72 match &self.args {
73 Value::Object(map) => map.iter().map(|(k, v)| (k.clone(), v.clone())).collect(),
74 _ => HashMap::new(),
75 }
76 }
77
78 fn args_as_positional(&self) -> Vec<String> {
81 if let Some(ref raw) = self.raw_args {
83 return raw.clone();
84 }
85 match &self.args {
86 Value::Array(arr) => arr
88 .iter()
89 .map(|v| match v {
90 Value::String(s) => s.clone(),
91 other => other.to_string(),
92 })
93 .collect(),
94 Value::String(s) => s.split_whitespace().map(String::from).collect(),
96 Value::Object(map) => {
98 if let Some(Value::Array(pos)) = map.get("_positional") {
99 return pos
100 .iter()
101 .map(|v| match v {
102 Value::String(s) => s.clone(),
103 other => other.to_string(),
104 })
105 .collect();
106 }
107 let mut result = Vec::new();
109 for (k, v) in map {
110 result.push(format!("--{k}"));
111 match v {
112 Value::String(s) => result.push(s.clone()),
113 Value::Bool(true) => {} other => result.push(other.to_string()),
115 }
116 }
117 result
118 }
119 _ => Vec::new(),
120 }
121 }
122}
123
/// JSON body returned by `POST /call`.
#[derive(Debug, Serialize)]
pub struct CallResponse {
    /// Tool output; `null` on most failures, the unprocessed upstream response when only
    /// response post-processing failed.
    pub result: Value,
    /// Human-readable failure description; omitted from the JSON on success.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
130
/// JSON body of `POST /help`.
#[derive(Debug, Deserialize)]
pub struct HelpRequest {
    /// The agent's natural-language question.
    pub query: String,
    /// Optional tool or provider name to focus the answer on; unknown names fall back
    /// to the unscoped prompt.
    #[serde(default)]
    pub tool: Option<String>,
}
137
/// JSON body returned by `POST /help`.
#[derive(Debug, Serialize)]
pub struct HelpResponse {
    /// The LLM's answer; empty when the request failed.
    pub content: String,
    /// Failure description; omitted from the JSON on success.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
144
/// JSON body returned by `GET /health`.
#[derive(Debug, Serialize)]
pub struct HealthResponse {
    /// Always `"ok"` when the handler runs.
    pub status: String,
    /// Crate version (`CARGO_PKG_VERSION`).
    pub version: String,
    /// Number of public tools in the registry.
    pub tools: usize,
    /// Number of registered providers.
    pub providers: usize,
    /// Number of loaded skills.
    pub skills: usize,
    /// `"jwt"` when JWT auth is configured, otherwise `"disabled"`.
    pub auth: String,
}
154
#[derive(Debug, Deserialize)]
/// Query-string filters for `GET /skills`.
///
/// Only one filter is applied; `handle_skills_list` checks them in the order
/// `search`, `category`, `provider`, `tool` and uses the first one present.
pub struct SkillsQuery {
    /// Restrict to skills in this category.
    #[serde(default)]
    pub category: Option<String>,
    /// Restrict to skills associated with this provider.
    #[serde(default)]
    pub provider: Option<String>,
    /// Restrict to skills associated with this tool.
    #[serde(default)]
    pub tool: Option<String>,
    /// Free-text search; takes precedence over the other filters.
    #[serde(default)]
    pub search: Option<String>,
}
168
/// Query-string options for `GET /skills/{name}`.
#[derive(Debug, Deserialize)]
pub struct SkillDetailQuery {
    /// When `true`, return only the skill's metadata instead of its content.
    #[serde(default)]
    pub meta: Option<bool>,
    /// When `true`, attach the skill's reference list to the content response.
    #[serde(default)]
    pub refs: Option<bool>,
}
176
/// JSON body of `POST /skills/resolve`.
#[derive(Debug, Deserialize)]
pub struct SkillResolveRequest {
    /// Scope strings to resolve skills for, in the same form a JWT's scope claim carries.
    pub scopes: Vec<String>,
}
181
182async fn handle_call(
185 State(state): State<Arc<ProxyState>>,
186 req: HttpRequest<Body>,
187) -> impl IntoResponse {
188 let claims = req.extensions().get::<TokenClaims>().cloned();
190
191 let body_bytes = match axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024).await {
193 Ok(b) => b,
194 Err(e) => {
195 return (
196 StatusCode::BAD_REQUEST,
197 Json(CallResponse {
198 result: Value::Null,
199 error: Some(format!("Failed to read request body: {e}")),
200 }),
201 );
202 }
203 };
204
205 let call_req: CallRequest = match serde_json::from_slice(&body_bytes) {
206 Ok(r) => r,
207 Err(e) => {
208 return (
209 StatusCode::UNPROCESSABLE_ENTITY,
210 Json(CallResponse {
211 result: Value::Null,
212 error: Some(format!("Invalid request: {e}")),
213 }),
214 );
215 }
216 };
217
218 tracing::debug!(
219 tool = %call_req.tool_name,
220 args = ?call_req.args,
221 "POST /call"
222 );
223
224 let (provider, tool) = match state.registry.get_tool(&call_req.tool_name) {
226 Some(pt) => pt,
227 None => {
228 return (
229 StatusCode::NOT_FOUND,
230 Json(CallResponse {
231 result: Value::Null,
232 error: Some(format!("Unknown tool: '{}'", call_req.tool_name)),
233 }),
234 );
235 }
236 };
237
238 if let Some(tool_scope) = &tool.scope {
240 let scopes = match &claims {
241 Some(c) => ScopeConfig::from_jwt(c),
242 None if state.jwt_config.is_none() => ScopeConfig::unrestricted(), None => {
244 return (
245 StatusCode::FORBIDDEN,
246 Json(CallResponse {
247 result: Value::Null,
248 error: Some("Authentication required — no JWT provided".into()),
249 }),
250 );
251 }
252 };
253
254 if let Err(e) = scopes.check_access(&call_req.tool_name, tool_scope) {
255 return (
256 StatusCode::FORBIDDEN,
257 Json(CallResponse {
258 result: Value::Null,
259 error: Some(format!("Access denied: {e}")),
260 }),
261 );
262 }
263 }
264
265 {
267 let scopes = match &claims {
268 Some(c) => ScopeConfig::from_jwt(c),
269 None => ScopeConfig::unrestricted(),
270 };
271 if let Some(ref rate_config) = scopes.rate_config {
272 if let Err(e) = crate::core::rate::check_and_record(&call_req.tool_name, rate_config) {
273 return (
274 StatusCode::TOO_MANY_REQUESTS,
275 Json(CallResponse {
276 result: Value::Null,
277 error: Some(format!("{e}")),
278 }),
279 );
280 }
281 }
282 }
283
284 let gen_ctx = GenContext {
286 jwt_sub: claims
287 .as_ref()
288 .map(|c| c.sub.clone())
289 .unwrap_or_else(|| "dev".into()),
290 jwt_scope: claims
291 .as_ref()
292 .map(|c| c.scope.clone())
293 .unwrap_or_else(|| "*".into()),
294 tool_name: call_req.tool_name.clone(),
295 timestamp: crate::core::jwt::now_secs(),
296 };
297
298 let agent_sub = claims.as_ref().map(|c| c.sub.clone()).unwrap_or_default();
300 let start = std::time::Instant::now();
301
302 let response = match provider.handler.as_str() {
303 "mcp" => {
304 let args_map = call_req.args_as_map();
305 match mcp_client::execute_with_gen(
306 provider,
307 &call_req.tool_name,
308 &args_map,
309 &state.keyring,
310 Some(&gen_ctx),
311 Some(&state.auth_cache),
312 )
313 .await
314 {
315 Ok(result) => (
316 StatusCode::OK,
317 Json(CallResponse {
318 result,
319 error: None,
320 }),
321 ),
322 Err(e) => (
323 StatusCode::BAD_GATEWAY,
324 Json(CallResponse {
325 result: Value::Null,
326 error: Some(format!("MCP error: {e}")),
327 }),
328 ),
329 }
330 }
331 "cli" => {
332 let positional = call_req.args_as_positional();
333 match crate::core::cli_executor::execute_with_gen(
334 provider,
335 &positional,
336 &state.keyring,
337 Some(&gen_ctx),
338 Some(&state.auth_cache),
339 )
340 .await
341 {
342 Ok(result) => (
343 StatusCode::OK,
344 Json(CallResponse {
345 result,
346 error: None,
347 }),
348 ),
349 Err(e) => (
350 StatusCode::BAD_GATEWAY,
351 Json(CallResponse {
352 result: Value::Null,
353 error: Some(format!("CLI error: {e}")),
354 }),
355 ),
356 }
357 }
358 _ => {
359 let args_map = call_req.args_as_map();
360 let raw_response = match match provider.handler.as_str() {
361 "xai" => xai::execute_xai_tool(provider, tool, &args_map, &state.keyring).await,
362 _ => {
363 http::execute_tool_with_gen(
364 provider,
365 tool,
366 &args_map,
367 &state.keyring,
368 Some(&gen_ctx),
369 Some(&state.auth_cache),
370 )
371 .await
372 }
373 } {
374 Ok(resp) => resp,
375 Err(e) => {
376 let duration = start.elapsed();
377 write_proxy_audit(&call_req, &agent_sub, duration, Some(&e.to_string()));
378 return (
379 StatusCode::BAD_GATEWAY,
380 Json(CallResponse {
381 result: Value::Null,
382 error: Some(format!("Upstream API error: {e}")),
383 }),
384 );
385 }
386 };
387
388 let processed = match response::process_response(&raw_response, tool.response.as_ref())
389 {
390 Ok(p) => p,
391 Err(e) => {
392 let duration = start.elapsed();
393 write_proxy_audit(&call_req, &agent_sub, duration, Some(&e.to_string()));
394 return (
395 StatusCode::INTERNAL_SERVER_ERROR,
396 Json(CallResponse {
397 result: raw_response,
398 error: Some(format!("Response processing error: {e}")),
399 }),
400 );
401 }
402 };
403
404 (
405 StatusCode::OK,
406 Json(CallResponse {
407 result: processed,
408 error: None,
409 }),
410 )
411 }
412 };
413
414 let duration = start.elapsed();
415 let error_msg = response.1.error.as_deref();
416 write_proxy_audit(&call_req, &agent_sub, duration, error_msg);
417
418 response
419}
420
421async fn handle_help(
422 State(state): State<Arc<ProxyState>>,
423 Json(req): Json<HelpRequest>,
424) -> impl IntoResponse {
425 tracing::debug!(query = %req.query, tool = ?req.tool, "POST /help");
426
427 let (llm_provider, llm_tool) = match state.registry.get_tool("_chat_completion") {
428 Some(pt) => pt,
429 None => {
430 return (
431 StatusCode::SERVICE_UNAVAILABLE,
432 Json(HelpResponse {
433 content: String::new(),
434 error: Some("No _llm.toml manifest found. Proxy help requires a configured LLM provider.".into()),
435 }),
436 );
437 }
438 };
439
440 let api_key = match llm_provider
441 .auth_key_name
442 .as_deref()
443 .and_then(|k| state.keyring.get(k))
444 {
445 Some(key) => key.to_string(),
446 None => {
447 return (
448 StatusCode::SERVICE_UNAVAILABLE,
449 Json(HelpResponse {
450 content: String::new(),
451 error: Some("LLM API key not found in keyring".into()),
452 }),
453 );
454 }
455 };
456
457 let scopes = ScopeConfig::unrestricted();
458 let resolved_skills = skill::resolve_skills(&state.skill_registry, &state.registry, &scopes);
459 let skills_section = if resolved_skills.is_empty() {
460 String::new()
461 } else {
462 format!(
463 "## Available Skills (methodology guides)\n{}",
464 skill::build_skill_context(&resolved_skills)
465 )
466 };
467
468 let system_prompt = if let Some(ref tool_name) = req.tool {
470 match build_scoped_prompt(tool_name, &state.registry, &skills_section) {
472 Some(prompt) => prompt,
473 None => {
474 tracing::debug!(scope = %tool_name, "scope not found, falling back to unscoped");
476 let all_tools = state.registry.list_public_tools();
477 let tools_context = build_tool_context(&all_tools);
478 HELP_SYSTEM_PROMPT
479 .replace("{tools}", &tools_context)
480 .replace("{skills_section}", &skills_section)
481 }
482 }
483 } else {
484 let all_tools = state.registry.list_public_tools();
485 let tools_context = build_tool_context(&all_tools);
486 HELP_SYSTEM_PROMPT
487 .replace("{tools}", &tools_context)
488 .replace("{skills_section}", &skills_section)
489 };
490
491 let request_body = serde_json::json!({
492 "model": "zai-glm-4.7",
493 "messages": [
494 {"role": "system", "content": system_prompt},
495 {"role": "user", "content": req.query}
496 ],
497 "max_completion_tokens": 1536,
498 "temperature": 0.3
499 });
500
501 let client = reqwest::Client::new();
502 let url = format!(
503 "{}{}",
504 llm_provider.base_url.trim_end_matches('/'),
505 llm_tool.endpoint
506 );
507
508 let response = match client
509 .post(&url)
510 .bearer_auth(&api_key)
511 .json(&request_body)
512 .send()
513 .await
514 {
515 Ok(r) => r,
516 Err(e) => {
517 return (
518 StatusCode::BAD_GATEWAY,
519 Json(HelpResponse {
520 content: String::new(),
521 error: Some(format!("LLM request failed: {e}")),
522 }),
523 );
524 }
525 };
526
527 if !response.status().is_success() {
528 let status = response.status();
529 let body = response.text().await.unwrap_or_default();
530 return (
531 StatusCode::BAD_GATEWAY,
532 Json(HelpResponse {
533 content: String::new(),
534 error: Some(format!("LLM API error ({status}): {body}")),
535 }),
536 );
537 }
538
539 let body: Value = match response.json().await {
540 Ok(b) => b,
541 Err(e) => {
542 return (
543 StatusCode::INTERNAL_SERVER_ERROR,
544 Json(HelpResponse {
545 content: String::new(),
546 error: Some(format!("Failed to parse LLM response: {e}")),
547 }),
548 );
549 }
550 };
551
552 let content = body
553 .pointer("/choices/0/message/content")
554 .and_then(|c| c.as_str())
555 .unwrap_or("No response from LLM")
556 .to_string();
557
558 (
559 StatusCode::OK,
560 Json(HelpResponse {
561 content,
562 error: None,
563 }),
564 )
565}
566
567async fn handle_health(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
568 let auth = if state.jwt_config.is_some() {
569 "jwt"
570 } else {
571 "disabled"
572 };
573
574 Json(HealthResponse {
575 status: "ok".into(),
576 version: env!("CARGO_PKG_VERSION").into(),
577 tools: state.registry.list_public_tools().len(),
578 providers: state.registry.list_providers().len(),
579 skills: state.skill_registry.skill_count(),
580 auth: auth.into(),
581 })
582}
583
584async fn handle_jwks(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
586 match &state.jwks_json {
587 Some(jwks) => (StatusCode::OK, Json(jwks.clone())),
588 None => (
589 StatusCode::NOT_FOUND,
590 Json(serde_json::json!({"error": "JWKS not configured"})),
591 ),
592 }
593}
594
595async fn handle_mcp(
600 State(state): State<Arc<ProxyState>>,
601 Json(msg): Json<Value>,
602) -> impl IntoResponse {
603 let method = msg.get("method").and_then(|m| m.as_str()).unwrap_or("");
604 let id = msg.get("id").cloned();
605
606 tracing::debug!(%method, "POST /mcp");
607
608 match method {
609 "initialize" => {
610 let result = serde_json::json!({
611 "protocolVersion": "2025-03-26",
612 "capabilities": {
613 "tools": { "listChanged": false }
614 },
615 "serverInfo": {
616 "name": "ati-proxy",
617 "version": env!("CARGO_PKG_VERSION")
618 }
619 });
620 jsonrpc_success(id, result)
621 }
622
623 "notifications/initialized" => (StatusCode::ACCEPTED, Json(Value::Null)),
624
625 "tools/list" => {
626 let all_tools = state.registry.list_public_tools();
627 let mcp_tools: Vec<Value> = all_tools
628 .iter()
629 .map(|(_provider, tool)| {
630 serde_json::json!({
631 "name": tool.name,
632 "description": tool.description,
633 "inputSchema": tool.input_schema.clone().unwrap_or(serde_json::json!({
634 "type": "object",
635 "properties": {}
636 }))
637 })
638 })
639 .collect();
640
641 let result = serde_json::json!({
642 "tools": mcp_tools,
643 });
644 jsonrpc_success(id, result)
645 }
646
647 "tools/call" => {
648 let params = msg.get("params").cloned().unwrap_or(Value::Null);
649 let tool_name = params.get("name").and_then(|n| n.as_str()).unwrap_or("");
650 let arguments: HashMap<String, Value> = params
651 .get("arguments")
652 .and_then(|a| serde_json::from_value(a.clone()).ok())
653 .unwrap_or_default();
654
655 if tool_name.is_empty() {
656 return jsonrpc_error(id, -32602, "Missing tool name in params.name");
657 }
658
659 let (provider, _tool) = match state.registry.get_tool(tool_name) {
660 Some(pt) => pt,
661 None => {
662 return jsonrpc_error(id, -32602, &format!("Unknown tool: '{tool_name}'"));
663 }
664 };
665
666 tracing::debug!(%tool_name, provider = %provider.name, "MCP tools/call");
667
668 let mcp_gen_ctx = GenContext {
669 jwt_sub: "dev".into(),
670 jwt_scope: "*".into(),
671 tool_name: tool_name.to_string(),
672 timestamp: crate::core::jwt::now_secs(),
673 };
674
675 let result = if provider.is_mcp() {
676 mcp_client::execute_with_gen(
677 provider,
678 tool_name,
679 &arguments,
680 &state.keyring,
681 Some(&mcp_gen_ctx),
682 Some(&state.auth_cache),
683 )
684 .await
685 } else if provider.is_cli() {
686 let raw: Vec<String> = arguments
688 .iter()
689 .flat_map(|(k, v)| {
690 let val = match v {
691 Value::String(s) => s.clone(),
692 other => other.to_string(),
693 };
694 vec![format!("--{k}"), val]
695 })
696 .collect();
697 crate::core::cli_executor::execute_with_gen(
698 provider,
699 &raw,
700 &state.keyring,
701 Some(&mcp_gen_ctx),
702 Some(&state.auth_cache),
703 )
704 .await
705 .map_err(|e| mcp_client::McpError::Transport(e.to_string()))
706 } else {
707 match match provider.handler.as_str() {
708 "xai" => {
709 xai::execute_xai_tool(provider, _tool, &arguments, &state.keyring).await
710 }
711 _ => {
712 http::execute_tool_with_gen(
713 provider,
714 _tool,
715 &arguments,
716 &state.keyring,
717 Some(&mcp_gen_ctx),
718 Some(&state.auth_cache),
719 )
720 .await
721 }
722 } {
723 Ok(val) => Ok(val),
724 Err(e) => Err(mcp_client::McpError::Transport(e.to_string())),
725 }
726 };
727
728 match result {
729 Ok(value) => {
730 let text = match &value {
731 Value::String(s) => s.clone(),
732 other => serde_json::to_string_pretty(other).unwrap_or_default(),
733 };
734 let mcp_result = serde_json::json!({
735 "content": [{"type": "text", "text": text}],
736 "isError": false,
737 });
738 jsonrpc_success(id, mcp_result)
739 }
740 Err(e) => {
741 let mcp_result = serde_json::json!({
742 "content": [{"type": "text", "text": format!("Error: {e}")}],
743 "isError": true,
744 });
745 jsonrpc_success(id, mcp_result)
746 }
747 }
748 }
749
750 _ => jsonrpc_error(id, -32601, &format!("Method not found: '{method}'")),
751 }
752}
753
754fn jsonrpc_success(id: Option<Value>, result: Value) -> (StatusCode, Json<Value>) {
755 (
756 StatusCode::OK,
757 Json(serde_json::json!({
758 "jsonrpc": "2.0",
759 "id": id,
760 "result": result,
761 })),
762 )
763}
764
765fn jsonrpc_error(id: Option<Value>, code: i64, message: &str) -> (StatusCode, Json<Value>) {
766 (
767 StatusCode::OK,
768 Json(serde_json::json!({
769 "jsonrpc": "2.0",
770 "id": id,
771 "error": {
772 "code": code,
773 "message": message,
774 }
775 })),
776 )
777}
778
779async fn handle_skills_list(
784 State(state): State<Arc<ProxyState>>,
785 axum::extract::Query(query): axum::extract::Query<SkillsQuery>,
786) -> impl IntoResponse {
787 tracing::debug!(
788 category = ?query.category,
789 provider = ?query.provider,
790 tool = ?query.tool,
791 search = ?query.search,
792 "GET /skills"
793 );
794
795 let skills: Vec<&skill::SkillMeta> = if let Some(search_query) = &query.search {
796 state.skill_registry.search(search_query)
797 } else if let Some(cat) = &query.category {
798 state.skill_registry.skills_for_category(cat)
799 } else if let Some(prov) = &query.provider {
800 state.skill_registry.skills_for_provider(prov)
801 } else if let Some(t) = &query.tool {
802 state.skill_registry.skills_for_tool(t)
803 } else {
804 state.skill_registry.list_skills().iter().collect()
805 };
806
807 let json: Vec<Value> = skills
808 .iter()
809 .map(|s| {
810 serde_json::json!({
811 "name": s.name,
812 "version": s.version,
813 "description": s.description,
814 "tools": s.tools,
815 "providers": s.providers,
816 "categories": s.categories,
817 "hint": s.hint,
818 })
819 })
820 .collect();
821
822 (StatusCode::OK, Json(Value::Array(json)))
823}
824
825async fn handle_skill_detail(
826 State(state): State<Arc<ProxyState>>,
827 axum::extract::Path(name): axum::extract::Path<String>,
828 axum::extract::Query(query): axum::extract::Query<SkillDetailQuery>,
829) -> impl IntoResponse {
830 tracing::debug!(%name, meta = ?query.meta, refs = ?query.refs, "GET /skills/:name");
831
832 let skill_meta = match state.skill_registry.get_skill(&name) {
833 Some(s) => s,
834 None => {
835 return (
836 StatusCode::NOT_FOUND,
837 Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
838 );
839 }
840 };
841
842 if query.meta.unwrap_or(false) {
843 return (
844 StatusCode::OK,
845 Json(serde_json::json!({
846 "name": skill_meta.name,
847 "version": skill_meta.version,
848 "description": skill_meta.description,
849 "author": skill_meta.author,
850 "tools": skill_meta.tools,
851 "providers": skill_meta.providers,
852 "categories": skill_meta.categories,
853 "keywords": skill_meta.keywords,
854 "hint": skill_meta.hint,
855 "depends_on": skill_meta.depends_on,
856 "suggests": skill_meta.suggests,
857 "license": skill_meta.license,
858 "compatibility": skill_meta.compatibility,
859 "allowed_tools": skill_meta.allowed_tools,
860 "format": skill_meta.format,
861 })),
862 );
863 }
864
865 let content = match state.skill_registry.read_content(&name) {
866 Ok(c) => c,
867 Err(e) => {
868 return (
869 StatusCode::INTERNAL_SERVER_ERROR,
870 Json(serde_json::json!({"error": format!("Failed to read skill: {e}")})),
871 );
872 }
873 };
874
875 let mut response = serde_json::json!({
876 "name": skill_meta.name,
877 "version": skill_meta.version,
878 "description": skill_meta.description,
879 "content": content,
880 });
881
882 if query.refs.unwrap_or(false) {
883 if let Ok(refs) = state.skill_registry.list_references(&name) {
884 response["references"] = serde_json::json!(refs);
885 }
886 }
887
888 (StatusCode::OK, Json(response))
889}
890
891async fn handle_skills_resolve(
892 State(state): State<Arc<ProxyState>>,
893 Json(req): Json<SkillResolveRequest>,
894) -> impl IntoResponse {
895 tracing::debug!(scopes = ?req.scopes, "POST /skills/resolve");
896
897 let scopes = ScopeConfig {
898 scopes: req.scopes,
899 sub: String::new(),
900 expires_at: 0,
901 rate_config: None,
902 };
903
904 let resolved = skill::resolve_skills(&state.skill_registry, &state.registry, &scopes);
905
906 let json: Vec<Value> = resolved
907 .iter()
908 .map(|s| {
909 serde_json::json!({
910 "name": s.name,
911 "version": s.version,
912 "description": s.description,
913 "tools": s.tools,
914 "providers": s.providers,
915 "categories": s.categories,
916 })
917 })
918 .collect();
919
920 (StatusCode::OK, Json(Value::Array(json)))
921}
922
923async fn auth_middleware(
931 State(state): State<Arc<ProxyState>>,
932 mut req: HttpRequest<Body>,
933 next: Next,
934) -> Result<Response, StatusCode> {
935 let path = req.uri().path();
936
937 if path == "/health" || path == "/.well-known/jwks.json" {
939 return Ok(next.run(req).await);
940 }
941
942 let jwt_config = match &state.jwt_config {
944 Some(c) => c,
945 None => return Ok(next.run(req).await),
946 };
947
948 let auth_header = req
950 .headers()
951 .get("authorization")
952 .and_then(|v| v.to_str().ok());
953
954 let token = match auth_header {
955 Some(header) if header.starts_with("Bearer ") => &header[7..],
956 _ => return Err(StatusCode::UNAUTHORIZED),
957 };
958
959 match jwt::validate(token, jwt_config) {
961 Ok(claims) => {
962 tracing::debug!(sub = %claims.sub, scopes = %claims.scope, "JWT validated");
963 req.extensions_mut().insert(claims);
964 Ok(next.run(req).await)
965 }
966 Err(e) => {
967 tracing::debug!(error = %e, "JWT validation failed");
968 Err(StatusCode::UNAUTHORIZED)
969 }
970 }
971}
972
973pub fn build_router(state: Arc<ProxyState>) -> Router {
977 Router::new()
978 .route("/call", post(handle_call))
979 .route("/help", post(handle_help))
980 .route("/mcp", post(handle_mcp))
981 .route("/skills", get(handle_skills_list))
982 .route("/skills/resolve", post(handle_skills_resolve))
983 .route("/skills/{name}", get(handle_skill_detail))
984 .route("/health", get(handle_health))
985 .route("/.well-known/jwks.json", get(handle_jwks))
986 .layer(middleware::from_fn_with_state(
987 state.clone(),
988 auth_middleware,
989 ))
990 .with_state(state)
991}
992
993pub async fn run(
997 port: u16,
998 bind_addr: Option<String>,
999 ati_dir: PathBuf,
1000 _verbose: bool,
1001 env_keys: bool,
1002) -> Result<(), Box<dyn std::error::Error>> {
1003 let manifests_dir = ati_dir.join("manifests");
1005 let registry = ManifestRegistry::load(&manifests_dir)?;
1006
1007 let tool_count = registry.list_public_tools().len();
1008 let provider_count = registry.list_providers().len();
1009
1010 let keyring_source;
1012 let keyring = if env_keys {
1013 let kr = Keyring::from_env();
1015 let key_names = kr.key_names();
1016 tracing::info!(
1017 count = key_names.len(),
1018 "loaded API keys from ATI_KEY_* env vars"
1019 );
1020 for name in &key_names {
1021 tracing::debug!(key = %name, "env key loaded");
1022 }
1023 keyring_source = "env-vars (ATI_KEY_*)";
1024 kr
1025 } else {
1026 let keyring_path = ati_dir.join("keyring.enc");
1028 if keyring_path.exists() {
1029 if let Ok(kr) = Keyring::load(&keyring_path) {
1030 keyring_source = "keyring.enc (sealed key)";
1031 kr
1032 } else if let Ok(kr) = Keyring::load_local(&keyring_path, &ati_dir) {
1033 keyring_source = "keyring.enc (persistent key)";
1034 kr
1035 } else {
1036 tracing::warn!("keyring.enc exists but could not be decrypted");
1037 keyring_source = "empty (decryption failed)";
1038 Keyring::empty()
1039 }
1040 } else {
1041 let creds_path = ati_dir.join("credentials");
1042 if creds_path.exists() {
1043 match Keyring::load_credentials(&creds_path) {
1044 Ok(kr) => {
1045 keyring_source = "credentials (plaintext)";
1046 kr
1047 }
1048 Err(e) => {
1049 tracing::warn!(error = %e, "failed to load credentials");
1050 keyring_source = "empty (credentials error)";
1051 Keyring::empty()
1052 }
1053 }
1054 } else {
1055 tracing::warn!("no keyring.enc or credentials found — running without API keys");
1056 tracing::warn!("tools requiring authentication will fail");
1057 keyring_source = "empty (no auth)";
1058 Keyring::empty()
1059 }
1060 }
1061 };
1062
1063 let mcp_providers: Vec<(String, String)> = registry
1065 .list_mcp_providers()
1066 .iter()
1067 .map(|p| (p.name.clone(), p.mcp_transport_type().to_string()))
1068 .collect();
1069 let mcp_count = mcp_providers.len();
1070 let openapi_providers: Vec<String> = registry
1071 .list_openapi_providers()
1072 .iter()
1073 .map(|p| p.name.clone())
1074 .collect();
1075 let openapi_count = openapi_providers.len();
1076
1077 let skills_dir = ati_dir.join("skills");
1079 let skill_registry = SkillRegistry::load(&skills_dir).unwrap_or_else(|e| {
1080 tracing::warn!(error = %e, "failed to load skills");
1081 SkillRegistry::load(std::path::Path::new("/nonexistent-fallback")).unwrap()
1082 });
1083 let skill_count = skill_registry.skill_count();
1084
1085 let jwt_config = match jwt::config_from_env() {
1087 Ok(config) => config,
1088 Err(e) => {
1089 tracing::warn!(error = %e, "JWT config error");
1090 None
1091 }
1092 };
1093
1094 let auth_status = if jwt_config.is_some() {
1095 "JWT enabled"
1096 } else {
1097 "DISABLED (no JWT keys configured)"
1098 };
1099
1100 let jwks_json = jwt_config.as_ref().and_then(|config| {
1102 config
1103 .public_key_pem
1104 .as_ref()
1105 .and_then(|pem| jwt::public_key_to_jwks(pem, config.algorithm, "ati-proxy-1").ok())
1106 });
1107
1108 let state = Arc::new(ProxyState {
1109 registry,
1110 skill_registry,
1111 keyring,
1112 jwt_config,
1113 jwks_json,
1114 auth_cache: AuthCache::new(),
1115 });
1116
1117 let app = build_router(state);
1118
1119 let addr: SocketAddr = if let Some(ref bind) = bind_addr {
1120 format!("{bind}:{port}").parse()?
1121 } else {
1122 SocketAddr::from(([127, 0, 0, 1], port))
1123 };
1124
1125 tracing::info!(
1126 version = env!("CARGO_PKG_VERSION"),
1127 %addr,
1128 auth = auth_status,
1129 ati_dir = %ati_dir.display(),
1130 tools = tool_count,
1131 providers = provider_count,
1132 mcp = mcp_count,
1133 openapi = openapi_count,
1134 skills = skill_count,
1135 keyring = keyring_source,
1136 "ATI proxy server starting"
1137 );
1138 for (name, transport) in &mcp_providers {
1139 tracing::info!(provider = %name, transport = %transport, "MCP provider");
1140 }
1141 for name in &openapi_providers {
1142 tracing::info!(provider = %name, "OpenAPI provider");
1143 }
1144
1145 let listener = tokio::net::TcpListener::bind(addr).await?;
1146 axum::serve(listener, app).await?;
1147
1148 Ok(())
1149}
1150
1151fn write_proxy_audit(
1153 call_req: &CallRequest,
1154 agent_sub: &str,
1155 duration: std::time::Duration,
1156 error: Option<&str>,
1157) {
1158 let entry = crate::core::audit::AuditEntry {
1159 ts: chrono::Utc::now().to_rfc3339(),
1160 tool: call_req.tool_name.clone(),
1161 args: crate::core::audit::sanitize_args(&call_req.args),
1162 status: if error.is_some() {
1163 crate::core::audit::AuditStatus::Error
1164 } else {
1165 crate::core::audit::AuditStatus::Ok
1166 },
1167 duration_ms: duration.as_millis() as u64,
1168 agent_sub: agent_sub.to_string(),
1169 error: error.map(|s| s.to_string()),
1170 exit_code: None,
1171 };
1172 let _ = crate::core::audit::append(&entry);
1173}
1174
/// System prompt for unscoped `POST /help` requests.
///
/// `handle_help` substitutes `{tools}` with the output of `build_tool_context` and
/// `{skills_section}` with the resolved-skills section, which is empty when no skills resolve.
const HELP_SYSTEM_PROMPT: &str = r#"You are a helpful assistant for an AI agent that uses external tools via the `ati` CLI.

## Available Tools
{tools}

{skills_section}

Answer the agent's question naturally, like a knowledgeable colleague would. Keep it short but useful:

- Explain which tools to use and why, with `ati run` commands showing realistic parameter values
- If multiple steps are needed, walk through them briefly in order
- Mention important gotchas or parameter choices that matter
- If skills are relevant, suggest `ati skill show <name>` for the full methodology

Keep your answer concise — a few short paragraphs with embedded code blocks. Only recommend tools from the list above."#;
1192
1193fn build_tool_context(
1194 tools: &[(
1195 &crate::core::manifest::Provider,
1196 &crate::core::manifest::Tool,
1197 )],
1198) -> String {
1199 let mut summaries = Vec::new();
1200 for (provider, tool) in tools {
1201 let mut summary = if let Some(cat) = &provider.category {
1202 format!(
1203 "- **{}** (provider: {}, category: {}): {}",
1204 tool.name, provider.name, cat, tool.description
1205 )
1206 } else {
1207 format!(
1208 "- **{}** (provider: {}): {}",
1209 tool.name, provider.name, tool.description
1210 )
1211 };
1212 if !tool.tags.is_empty() {
1213 summary.push_str(&format!("\n Tags: {}", tool.tags.join(", ")));
1214 }
1215 if provider.is_cli() && tool.input_schema.is_none() {
1217 let cmd = provider.cli_command.as_deref().unwrap_or("?");
1218 summary.push_str(&format!(
1219 "\n Usage: `ati run {} -- <args>` (passthrough to `{}`)",
1220 tool.name, cmd
1221 ));
1222 } else if let Some(schema) = &tool.input_schema {
1223 if let Some(props) = schema.get("properties") {
1224 if let Some(obj) = props.as_object() {
1225 let params: Vec<String> = obj
1226 .iter()
1227 .filter(|(_, v)| {
1228 v.get("x-ati-param-location").is_none()
1229 || v.get("description").is_some()
1230 })
1231 .map(|(k, v)| {
1232 let type_str =
1233 v.get("type").and_then(|t| t.as_str()).unwrap_or("string");
1234 let desc = v.get("description").and_then(|d| d.as_str()).unwrap_or("");
1235 format!(" --{k} ({type_str}): {desc}")
1236 })
1237 .collect();
1238 if !params.is_empty() {
1239 summary.push_str("\n Parameters:\n");
1240 summary.push_str(¶ms.join("\n"));
1241 }
1242 }
1243 }
1244 }
1245 summaries.push(summary);
1246 }
1247 summaries.join("\n\n")
1248}
1249
1250fn build_scoped_prompt(
1254 scope_name: &str,
1255 registry: &ManifestRegistry,
1256 skills_section: &str,
1257) -> Option<String> {
1258 if let Some((provider, tool)) = registry.get_tool(scope_name) {
1260 let mut details = format!(
1261 "**Name**: `{}`\n**Provider**: {} (handler: {})\n**Description**: {}\n",
1262 tool.name, provider.name, provider.handler, tool.description
1263 );
1264 if let Some(cat) = &provider.category {
1265 details.push_str(&format!("**Category**: {}\n", cat));
1266 }
1267 if provider.is_cli() {
1268 let cmd = provider.cli_command.as_deref().unwrap_or("?");
1269 details.push_str(&format!(
1270 "\n**Usage**: `ati run {} -- <args>` (passthrough to `{}`)\n",
1271 tool.name, cmd
1272 ));
1273 } else if let Some(schema) = &tool.input_schema {
1274 if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
1275 let required: Vec<String> = schema
1276 .get("required")
1277 .and_then(|r| r.as_array())
1278 .map(|arr| {
1279 arr.iter()
1280 .filter_map(|v| v.as_str().map(|s| s.to_string()))
1281 .collect()
1282 })
1283 .unwrap_or_default();
1284 details.push_str("\n**Parameters**:\n");
1285 for (key, val) in props {
1286 let type_str = val.get("type").and_then(|t| t.as_str()).unwrap_or("string");
1287 let desc = val
1288 .get("description")
1289 .and_then(|d| d.as_str())
1290 .unwrap_or("");
1291 let req = if required.contains(key) {
1292 " **(required)**"
1293 } else {
1294 ""
1295 };
1296 details.push_str(&format!("- `--{key}` ({type_str}{req}): {desc}\n"));
1297 }
1298 }
1299 }
1300
1301 let prompt = format!(
1302 "You are an expert assistant for the `{}` tool, accessed via the `ati` CLI.\n\n\
1303 ## Tool Details\n{}\n\n{}\n\n\
1304 Answer the agent's question about this specific tool. Provide exact commands, explain flags and options, and give practical examples. Be concise and actionable.",
1305 tool.name, details, skills_section
1306 );
1307 return Some(prompt);
1308 }
1309
1310 if registry.has_provider(scope_name) {
1312 let tools = registry.tools_by_provider(scope_name);
1313 if tools.is_empty() {
1314 return None;
1315 }
1316 let tools_context = build_tool_context(&tools);
1317 let prompt = format!(
1318 "You are an expert assistant for the `{}` provider's tools, accessed via the `ati` CLI.\n\n\
1319 ## Tools in provider `{}`\n{}\n\n{}\n\n\
1320 Answer the agent's question about these tools. Provide exact `ati run` commands, explain parameters, and give practical examples. Be concise and actionable.",
1321 scope_name, scope_name, tools_context, skills_section
1322 );
1323 return Some(prompt);
1324 }
1325
1326 None
1327}