1use axum::{
8 body::Body,
9 extract::State,
10 http::{Request as HttpRequest, StatusCode},
11 middleware::{self, Next},
12 response::{IntoResponse, Response},
13 routing::{get, post},
14 Json, Router,
15};
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use std::collections::HashMap;
19use std::net::SocketAddr;
20use std::path::PathBuf;
21use std::sync::Arc;
22
23use crate::core::auth_generator::{AuthCache, GenContext};
24use crate::core::http;
25use crate::core::jwt::{self, JwtConfig, TokenClaims};
26use crate::core::keyring::Keyring;
27use crate::core::manifest::ManifestRegistry;
28use crate::core::mcp_client;
29use crate::core::response;
30use crate::core::scope::ScopeConfig;
31use crate::core::skill::{self, SkillRegistry};
32use crate::core::xai;
33
/// State shared by every request handler (wrapped in an `Arc` by [`run`]).
pub struct ProxyState {
    /// Manifest registry used to look up providers and tools by name.
    pub registry: ManifestRegistry,
    /// Skill registry backing the `/skills` endpoints and the `/help` context.
    pub skill_registry: SkillRegistry,
    /// Credential store handed to the tool executors.
    pub keyring: Keyring,
    /// When true, handlers log a summary of each request to stderr.
    pub verbose: bool,
    /// JWT validation settings; when `None` the auth middleware lets every request through.
    pub jwt_config: Option<JwtConfig>,
    /// JWKS document served at `/.well-known/jwks.json`, when one could be built.
    pub jwks_json: Option<Value>,
    /// Cache passed to the `*_with_gen` executors together with the generator context.
    pub auth_cache: AuthCache,
}
47
/// JSON body accepted by `POST /call`.
#[derive(Debug, Deserialize)]
pub struct CallRequest {
    /// Name of the tool to run; looked up in the manifest registry.
    pub tool_name: String,
    /// Named arguments forwarded to the MCP, xAI and HTTP executors.
    pub args: HashMap<String, Value>,
    /// Positional arguments forwarded to the CLI executor; a missing field is
    /// treated as an empty list.
    #[serde(default)]
    pub raw_args: Option<Vec<String>>,
}
57
/// JSON body returned by `POST /call`.
#[derive(Debug, Serialize)]
pub struct CallResponse {
    /// Tool output on success; `null` on most failures.
    pub result: Value,
    /// Human-readable failure description; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
64
/// JSON body accepted by `POST /help`.
#[derive(Debug, Deserialize)]
pub struct HelpRequest {
    /// The question, sent to the LLM as the user message.
    pub query: String,
    /// Optional tool or provider name used to narrow the system prompt; when it
    /// matches nothing the handler falls back to the unscoped prompt.
    #[serde(default)]
    pub tool: Option<String>,
}
71
/// JSON body returned by `POST /help`.
#[derive(Debug, Serialize)]
pub struct HelpResponse {
    /// The LLM's answer; empty when `error` is set.
    pub content: String,
    /// Human-readable failure description; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
78
/// JSON body returned by `GET /health`.
#[derive(Debug, Serialize)]
pub struct HealthResponse {
    /// Always `"ok"` when the handler runs.
    pub status: String,
    /// Crate version taken from `CARGO_PKG_VERSION` at build time.
    pub version: String,
    /// Number of public tools in the manifest registry.
    pub tools: usize,
    /// Number of providers in the manifest registry.
    pub providers: usize,
    /// Number of skills in the skill registry.
    pub skills: usize,
    /// `"jwt"` when a JWT config is present, `"disabled"` otherwise.
    pub auth: String,
}
88
/// Query-string filters for `GET /skills`.
///
/// Only one filter is applied per request; the handler checks them in the
/// order `search`, `category`, `provider`, `tool` and uses the first one set.
#[derive(Debug, Deserialize)]
pub struct SkillsQuery {
    /// Restrict the listing to skills in this category.
    #[serde(default)]
    pub category: Option<String>,
    /// Restrict the listing to skills associated with this provider.
    #[serde(default)]
    pub provider: Option<String>,
    /// Restrict the listing to skills associated with this tool.
    #[serde(default)]
    pub tool: Option<String>,
    /// Free-text query passed to the skill registry's search.
    #[serde(default)]
    pub search: Option<String>,
}
102
/// Query-string options for `GET /skills/{name}`.
#[derive(Debug, Deserialize)]
pub struct SkillDetailQuery {
    /// When `true`, return the skill's metadata only and skip reading its content.
    #[serde(default)]
    pub meta: Option<bool>,
    /// When `true`, add the skill's reference list to the response.
    #[serde(default)]
    pub refs: Option<bool>,
}
110
/// JSON body accepted by `POST /skills/resolve`.
#[derive(Debug, Deserialize)]
pub struct SkillResolveRequest {
    /// Scope strings used to build the `ScopeConfig` that skills are resolved against.
    pub scopes: Vec<String>,
}
115
116async fn handle_call(
119 State(state): State<Arc<ProxyState>>,
120 req: HttpRequest<Body>,
121) -> impl IntoResponse {
122 let claims = req.extensions().get::<TokenClaims>().cloned();
124
125 let body_bytes = match axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024).await {
127 Ok(b) => b,
128 Err(e) => {
129 return (
130 StatusCode::BAD_REQUEST,
131 Json(CallResponse {
132 result: Value::Null,
133 error: Some(format!("Failed to read request body: {e}")),
134 }),
135 );
136 }
137 };
138
139 let call_req: CallRequest = match serde_json::from_slice(&body_bytes) {
140 Ok(r) => r,
141 Err(e) => {
142 return (
143 StatusCode::UNPROCESSABLE_ENTITY,
144 Json(CallResponse {
145 result: Value::Null,
146 error: Some(format!("Invalid request: {e}")),
147 }),
148 );
149 }
150 };
151
152 if state.verbose {
153 eprintln!(
154 "[proxy] /call tool={} args={:?}",
155 call_req.tool_name, call_req.args
156 );
157 }
158
159 let (provider, tool) = match state.registry.get_tool(&call_req.tool_name) {
161 Some(pt) => pt,
162 None => {
163 return (
164 StatusCode::NOT_FOUND,
165 Json(CallResponse {
166 result: Value::Null,
167 error: Some(format!("Unknown tool: '{}'", call_req.tool_name)),
168 }),
169 );
170 }
171 };
172
173 if let Some(tool_scope) = &tool.scope {
175 let scopes = match &claims {
176 Some(c) => ScopeConfig::from_jwt(c),
177 None if state.jwt_config.is_none() => ScopeConfig::unrestricted(), None => {
179 return (
180 StatusCode::FORBIDDEN,
181 Json(CallResponse {
182 result: Value::Null,
183 error: Some("Authentication required — no JWT provided".into()),
184 }),
185 );
186 }
187 };
188
189 if let Err(e) = scopes.check_access(&call_req.tool_name, tool_scope) {
190 return (
191 StatusCode::FORBIDDEN,
192 Json(CallResponse {
193 result: Value::Null,
194 error: Some(format!("Access denied: {e}")),
195 }),
196 );
197 }
198 }
199
200 {
202 let scopes = match &claims {
203 Some(c) => ScopeConfig::from_jwt(c),
204 None => ScopeConfig::unrestricted(),
205 };
206 if let Some(ref rate_config) = scopes.rate_config {
207 if let Err(e) = crate::core::rate::check_and_record(&call_req.tool_name, rate_config) {
208 return (
209 StatusCode::TOO_MANY_REQUESTS,
210 Json(CallResponse {
211 result: Value::Null,
212 error: Some(format!("{e}")),
213 }),
214 );
215 }
216 }
217 }
218
219 let gen_ctx = GenContext {
221 jwt_sub: claims
222 .as_ref()
223 .map(|c| c.sub.clone())
224 .unwrap_or_else(|| "dev".into()),
225 jwt_scope: claims
226 .as_ref()
227 .map(|c| c.scope.clone())
228 .unwrap_or_else(|| "*".into()),
229 tool_name: call_req.tool_name.clone(),
230 timestamp: crate::core::jwt::now_secs(),
231 };
232
233 let agent_sub = claims.as_ref().map(|c| c.sub.clone()).unwrap_or_default();
235 let start = std::time::Instant::now();
236
237 let response = match provider.handler.as_str() {
238 "mcp" => {
239 match mcp_client::execute_with_gen(
240 provider,
241 &call_req.tool_name,
242 &call_req.args,
243 &state.keyring,
244 Some(&gen_ctx),
245 Some(&state.auth_cache),
246 )
247 .await
248 {
249 Ok(result) => (
250 StatusCode::OK,
251 Json(CallResponse {
252 result,
253 error: None,
254 }),
255 ),
256 Err(e) => (
257 StatusCode::BAD_GATEWAY,
258 Json(CallResponse {
259 result: Value::Null,
260 error: Some(format!("MCP error: {e}")),
261 }),
262 ),
263 }
264 }
265 "cli" => {
266 let raw = call_req.raw_args.as_deref().unwrap_or(&[]);
267 match crate::core::cli_executor::execute_with_gen(
268 provider,
269 raw,
270 &state.keyring,
271 Some(&gen_ctx),
272 Some(&state.auth_cache),
273 )
274 .await
275 {
276 Ok(result) => (
277 StatusCode::OK,
278 Json(CallResponse {
279 result,
280 error: None,
281 }),
282 ),
283 Err(e) => (
284 StatusCode::BAD_GATEWAY,
285 Json(CallResponse {
286 result: Value::Null,
287 error: Some(format!("CLI error: {e}")),
288 }),
289 ),
290 }
291 }
292 _ => {
293 let raw_response = match match provider.handler.as_str() {
294 "xai" => {
295 xai::execute_xai_tool(provider, tool, &call_req.args, &state.keyring).await
296 }
297 _ => {
298 http::execute_tool_with_gen(
299 provider,
300 tool,
301 &call_req.args,
302 &state.keyring,
303 Some(&gen_ctx),
304 Some(&state.auth_cache),
305 )
306 .await
307 }
308 } {
309 Ok(resp) => resp,
310 Err(e) => {
311 let duration = start.elapsed();
312 write_proxy_audit(&call_req, &agent_sub, duration, Some(&e.to_string()));
313 return (
314 StatusCode::BAD_GATEWAY,
315 Json(CallResponse {
316 result: Value::Null,
317 error: Some(format!("Upstream API error: {e}")),
318 }),
319 );
320 }
321 };
322
323 let processed = match response::process_response(&raw_response, tool.response.as_ref())
324 {
325 Ok(p) => p,
326 Err(e) => {
327 let duration = start.elapsed();
328 write_proxy_audit(&call_req, &agent_sub, duration, Some(&e.to_string()));
329 return (
330 StatusCode::INTERNAL_SERVER_ERROR,
331 Json(CallResponse {
332 result: raw_response,
333 error: Some(format!("Response processing error: {e}")),
334 }),
335 );
336 }
337 };
338
339 (
340 StatusCode::OK,
341 Json(CallResponse {
342 result: processed,
343 error: None,
344 }),
345 )
346 }
347 };
348
349 let duration = start.elapsed();
350 let error_msg = response.1.error.as_deref();
351 write_proxy_audit(&call_req, &agent_sub, duration, error_msg);
352
353 response
354}
355
356async fn handle_help(
357 State(state): State<Arc<ProxyState>>,
358 Json(req): Json<HelpRequest>,
359) -> impl IntoResponse {
360 if state.verbose {
361 eprintln!("[proxy] /help query={} tool={:?}", req.query, req.tool);
362 }
363
364 let (llm_provider, llm_tool) = match state.registry.get_tool("_chat_completion") {
365 Some(pt) => pt,
366 None => {
367 return (
368 StatusCode::SERVICE_UNAVAILABLE,
369 Json(HelpResponse {
370 content: String::new(),
371 error: Some("No _llm.toml manifest found. Proxy help requires a configured LLM provider.".into()),
372 }),
373 );
374 }
375 };
376
377 let api_key = match llm_provider
378 .auth_key_name
379 .as_deref()
380 .and_then(|k| state.keyring.get(k))
381 {
382 Some(key) => key.to_string(),
383 None => {
384 return (
385 StatusCode::SERVICE_UNAVAILABLE,
386 Json(HelpResponse {
387 content: String::new(),
388 error: Some("LLM API key not found in keyring".into()),
389 }),
390 );
391 }
392 };
393
394 let scopes = ScopeConfig::unrestricted();
395 let resolved_skills = skill::resolve_skills(&state.skill_registry, &state.registry, &scopes);
396 let skills_section = if resolved_skills.is_empty() {
397 String::new()
398 } else {
399 format!(
400 "## Available Skills (methodology guides)\n{}",
401 skill::build_skill_context(&resolved_skills)
402 )
403 };
404
405 let system_prompt = if let Some(ref tool_name) = req.tool {
407 match build_scoped_prompt(tool_name, &state.registry, &skills_section) {
409 Some(prompt) => prompt,
410 None => {
411 if state.verbose {
413 eprintln!(
414 "[proxy] /help scope '{}' not found, falling back to unscoped",
415 tool_name
416 );
417 }
418 let all_tools = state.registry.list_public_tools();
419 let tools_context = build_tool_context(&all_tools);
420 HELP_SYSTEM_PROMPT
421 .replace("{tools}", &tools_context)
422 .replace("{skills_section}", &skills_section)
423 }
424 }
425 } else {
426 let all_tools = state.registry.list_public_tools();
427 let tools_context = build_tool_context(&all_tools);
428 HELP_SYSTEM_PROMPT
429 .replace("{tools}", &tools_context)
430 .replace("{skills_section}", &skills_section)
431 };
432
433 let request_body = serde_json::json!({
434 "model": "zai-glm-4.7",
435 "messages": [
436 {"role": "system", "content": system_prompt},
437 {"role": "user", "content": req.query}
438 ],
439 "max_completion_tokens": 1536,
440 "temperature": 0.3
441 });
442
443 let client = reqwest::Client::new();
444 let url = format!(
445 "{}{}",
446 llm_provider.base_url.trim_end_matches('/'),
447 llm_tool.endpoint
448 );
449
450 let response = match client
451 .post(&url)
452 .bearer_auth(&api_key)
453 .json(&request_body)
454 .send()
455 .await
456 {
457 Ok(r) => r,
458 Err(e) => {
459 return (
460 StatusCode::BAD_GATEWAY,
461 Json(HelpResponse {
462 content: String::new(),
463 error: Some(format!("LLM request failed: {e}")),
464 }),
465 );
466 }
467 };
468
469 if !response.status().is_success() {
470 let status = response.status();
471 let body = response.text().await.unwrap_or_default();
472 return (
473 StatusCode::BAD_GATEWAY,
474 Json(HelpResponse {
475 content: String::new(),
476 error: Some(format!("LLM API error ({status}): {body}")),
477 }),
478 );
479 }
480
481 let body: Value = match response.json().await {
482 Ok(b) => b,
483 Err(e) => {
484 return (
485 StatusCode::INTERNAL_SERVER_ERROR,
486 Json(HelpResponse {
487 content: String::new(),
488 error: Some(format!("Failed to parse LLM response: {e}")),
489 }),
490 );
491 }
492 };
493
494 let content = body
495 .pointer("/choices/0/message/content")
496 .and_then(|c| c.as_str())
497 .unwrap_or("No response from LLM")
498 .to_string();
499
500 (
501 StatusCode::OK,
502 Json(HelpResponse {
503 content,
504 error: None,
505 }),
506 )
507}
508
509async fn handle_health(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
510 let auth = if state.jwt_config.is_some() {
511 "jwt"
512 } else {
513 "disabled"
514 };
515
516 Json(HealthResponse {
517 status: "ok".into(),
518 version: env!("CARGO_PKG_VERSION").into(),
519 tools: state.registry.list_public_tools().len(),
520 providers: state.registry.list_providers().len(),
521 skills: state.skill_registry.skill_count(),
522 auth: auth.into(),
523 })
524}
525
526async fn handle_jwks(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
528 match &state.jwks_json {
529 Some(jwks) => (StatusCode::OK, Json(jwks.clone())),
530 None => (
531 StatusCode::NOT_FOUND,
532 Json(serde_json::json!({"error": "JWKS not configured"})),
533 ),
534 }
535}
536
537async fn handle_mcp(
542 State(state): State<Arc<ProxyState>>,
543 Json(msg): Json<Value>,
544) -> impl IntoResponse {
545 let method = msg.get("method").and_then(|m| m.as_str()).unwrap_or("");
546 let id = msg.get("id").cloned();
547
548 if state.verbose {
549 eprintln!("[proxy] /mcp method={method}");
550 }
551
552 match method {
553 "initialize" => {
554 let result = serde_json::json!({
555 "protocolVersion": "2025-03-26",
556 "capabilities": {
557 "tools": { "listChanged": false }
558 },
559 "serverInfo": {
560 "name": "ati-proxy",
561 "version": env!("CARGO_PKG_VERSION")
562 }
563 });
564 jsonrpc_success(id, result)
565 }
566
567 "notifications/initialized" => (StatusCode::ACCEPTED, Json(Value::Null)),
568
569 "tools/list" => {
570 let all_tools = state.registry.list_public_tools();
571 let mcp_tools: Vec<Value> = all_tools
572 .iter()
573 .map(|(_provider, tool)| {
574 serde_json::json!({
575 "name": tool.name,
576 "description": tool.description,
577 "inputSchema": tool.input_schema.clone().unwrap_or(serde_json::json!({
578 "type": "object",
579 "properties": {}
580 }))
581 })
582 })
583 .collect();
584
585 let result = serde_json::json!({
586 "tools": mcp_tools,
587 });
588 jsonrpc_success(id, result)
589 }
590
591 "tools/call" => {
592 let params = msg.get("params").cloned().unwrap_or(Value::Null);
593 let tool_name = params.get("name").and_then(|n| n.as_str()).unwrap_or("");
594 let arguments: HashMap<String, Value> = params
595 .get("arguments")
596 .and_then(|a| serde_json::from_value(a.clone()).ok())
597 .unwrap_or_default();
598
599 if tool_name.is_empty() {
600 return jsonrpc_error(id, -32602, "Missing tool name in params.name");
601 }
602
603 let (provider, _tool) = match state.registry.get_tool(tool_name) {
604 Some(pt) => pt,
605 None => {
606 return jsonrpc_error(id, -32602, &format!("Unknown tool: '{tool_name}'"));
607 }
608 };
609
610 if state.verbose {
611 eprintln!(
612 "[proxy] /mcp tools/call name={tool_name} provider={}",
613 provider.name
614 );
615 }
616
617 let mcp_gen_ctx = GenContext {
618 jwt_sub: "dev".into(),
619 jwt_scope: "*".into(),
620 tool_name: tool_name.to_string(),
621 timestamp: crate::core::jwt::now_secs(),
622 };
623
624 let result = if provider.is_mcp() {
625 mcp_client::execute_with_gen(
626 provider,
627 tool_name,
628 &arguments,
629 &state.keyring,
630 Some(&mcp_gen_ctx),
631 Some(&state.auth_cache),
632 )
633 .await
634 } else if provider.is_cli() {
635 let raw: Vec<String> = arguments
637 .iter()
638 .flat_map(|(k, v)| {
639 let val = match v {
640 Value::String(s) => s.clone(),
641 other => other.to_string(),
642 };
643 vec![format!("--{k}"), val]
644 })
645 .collect();
646 crate::core::cli_executor::execute_with_gen(
647 provider,
648 &raw,
649 &state.keyring,
650 Some(&mcp_gen_ctx),
651 Some(&state.auth_cache),
652 )
653 .await
654 .map_err(|e| mcp_client::McpError::Transport(e.to_string()))
655 } else {
656 match match provider.handler.as_str() {
657 "xai" => {
658 xai::execute_xai_tool(provider, _tool, &arguments, &state.keyring).await
659 }
660 _ => {
661 http::execute_tool_with_gen(
662 provider,
663 _tool,
664 &arguments,
665 &state.keyring,
666 Some(&mcp_gen_ctx),
667 Some(&state.auth_cache),
668 )
669 .await
670 }
671 } {
672 Ok(val) => Ok(val),
673 Err(e) => Err(mcp_client::McpError::Transport(e.to_string())),
674 }
675 };
676
677 match result {
678 Ok(value) => {
679 let text = match &value {
680 Value::String(s) => s.clone(),
681 other => serde_json::to_string_pretty(other).unwrap_or_default(),
682 };
683 let mcp_result = serde_json::json!({
684 "content": [{"type": "text", "text": text}],
685 "isError": false,
686 });
687 jsonrpc_success(id, mcp_result)
688 }
689 Err(e) => {
690 let mcp_result = serde_json::json!({
691 "content": [{"type": "text", "text": format!("Error: {e}")}],
692 "isError": true,
693 });
694 jsonrpc_success(id, mcp_result)
695 }
696 }
697 }
698
699 _ => jsonrpc_error(id, -32601, &format!("Method not found: '{method}'")),
700 }
701}
702
703fn jsonrpc_success(id: Option<Value>, result: Value) -> (StatusCode, Json<Value>) {
704 (
705 StatusCode::OK,
706 Json(serde_json::json!({
707 "jsonrpc": "2.0",
708 "id": id,
709 "result": result,
710 })),
711 )
712}
713
714fn jsonrpc_error(id: Option<Value>, code: i64, message: &str) -> (StatusCode, Json<Value>) {
715 (
716 StatusCode::OK,
717 Json(serde_json::json!({
718 "jsonrpc": "2.0",
719 "id": id,
720 "error": {
721 "code": code,
722 "message": message,
723 }
724 })),
725 )
726}
727
728async fn handle_skills_list(
733 State(state): State<Arc<ProxyState>>,
734 axum::extract::Query(query): axum::extract::Query<SkillsQuery>,
735) -> impl IntoResponse {
736 if state.verbose {
737 eprintln!(
738 "[proxy] GET /skills category={:?} provider={:?} tool={:?} search={:?}",
739 query.category, query.provider, query.tool, query.search
740 );
741 }
742
743 let skills: Vec<&skill::SkillMeta> = if let Some(search_query) = &query.search {
744 state.skill_registry.search(search_query)
745 } else if let Some(cat) = &query.category {
746 state.skill_registry.skills_for_category(cat)
747 } else if let Some(prov) = &query.provider {
748 state.skill_registry.skills_for_provider(prov)
749 } else if let Some(t) = &query.tool {
750 state.skill_registry.skills_for_tool(t)
751 } else {
752 state.skill_registry.list_skills().iter().collect()
753 };
754
755 let json: Vec<Value> = skills
756 .iter()
757 .map(|s| {
758 serde_json::json!({
759 "name": s.name,
760 "version": s.version,
761 "description": s.description,
762 "tools": s.tools,
763 "providers": s.providers,
764 "categories": s.categories,
765 "hint": s.hint,
766 })
767 })
768 .collect();
769
770 (StatusCode::OK, Json(Value::Array(json)))
771}
772
773async fn handle_skill_detail(
774 State(state): State<Arc<ProxyState>>,
775 axum::extract::Path(name): axum::extract::Path<String>,
776 axum::extract::Query(query): axum::extract::Query<SkillDetailQuery>,
777) -> impl IntoResponse {
778 if state.verbose {
779 eprintln!(
780 "[proxy] GET /skills/{name} meta={:?} refs={:?}",
781 query.meta, query.refs
782 );
783 }
784
785 let skill_meta = match state.skill_registry.get_skill(&name) {
786 Some(s) => s,
787 None => {
788 return (
789 StatusCode::NOT_FOUND,
790 Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
791 );
792 }
793 };
794
795 if query.meta.unwrap_or(false) {
796 return (
797 StatusCode::OK,
798 Json(serde_json::json!({
799 "name": skill_meta.name,
800 "version": skill_meta.version,
801 "description": skill_meta.description,
802 "author": skill_meta.author,
803 "tools": skill_meta.tools,
804 "providers": skill_meta.providers,
805 "categories": skill_meta.categories,
806 "keywords": skill_meta.keywords,
807 "hint": skill_meta.hint,
808 "depends_on": skill_meta.depends_on,
809 "suggests": skill_meta.suggests,
810 "license": skill_meta.license,
811 "compatibility": skill_meta.compatibility,
812 "allowed_tools": skill_meta.allowed_tools,
813 "format": skill_meta.format,
814 })),
815 );
816 }
817
818 let content = match state.skill_registry.read_content(&name) {
819 Ok(c) => c,
820 Err(e) => {
821 return (
822 StatusCode::INTERNAL_SERVER_ERROR,
823 Json(serde_json::json!({"error": format!("Failed to read skill: {e}")})),
824 );
825 }
826 };
827
828 let mut response = serde_json::json!({
829 "name": skill_meta.name,
830 "version": skill_meta.version,
831 "description": skill_meta.description,
832 "content": content,
833 });
834
835 if query.refs.unwrap_or(false) {
836 if let Ok(refs) = state.skill_registry.list_references(&name) {
837 response["references"] = serde_json::json!(refs);
838 }
839 }
840
841 (StatusCode::OK, Json(response))
842}
843
844async fn handle_skills_resolve(
845 State(state): State<Arc<ProxyState>>,
846 Json(req): Json<SkillResolveRequest>,
847) -> impl IntoResponse {
848 if state.verbose {
849 eprintln!("[proxy] POST /skills/resolve scopes={:?}", req.scopes);
850 }
851
852 let scopes = ScopeConfig {
853 scopes: req.scopes,
854 sub: String::new(),
855 expires_at: 0,
856 rate_config: None,
857 };
858
859 let resolved = skill::resolve_skills(&state.skill_registry, &state.registry, &scopes);
860
861 let json: Vec<Value> = resolved
862 .iter()
863 .map(|s| {
864 serde_json::json!({
865 "name": s.name,
866 "version": s.version,
867 "description": s.description,
868 "tools": s.tools,
869 "providers": s.providers,
870 "categories": s.categories,
871 })
872 })
873 .collect();
874
875 (StatusCode::OK, Json(Value::Array(json)))
876}
877
878async fn auth_middleware(
886 State(state): State<Arc<ProxyState>>,
887 mut req: HttpRequest<Body>,
888 next: Next,
889) -> Result<Response, StatusCode> {
890 let path = req.uri().path();
891
892 if path == "/health" || path == "/.well-known/jwks.json" {
894 return Ok(next.run(req).await);
895 }
896
897 let jwt_config = match &state.jwt_config {
899 Some(c) => c,
900 None => return Ok(next.run(req).await),
901 };
902
903 let auth_header = req
905 .headers()
906 .get("authorization")
907 .and_then(|v| v.to_str().ok());
908
909 let token = match auth_header {
910 Some(header) if header.starts_with("Bearer ") => &header[7..],
911 _ => return Err(StatusCode::UNAUTHORIZED),
912 };
913
914 match jwt::validate(token, jwt_config) {
916 Ok(claims) => {
917 if state.verbose {
918 eprintln!(
919 "[proxy] JWT validated: sub={} scopes={}",
920 claims.sub, claims.scope
921 );
922 }
923 req.extensions_mut().insert(claims);
924 Ok(next.run(req).await)
925 }
926 Err(e) => {
927 if state.verbose {
928 eprintln!("[proxy] JWT validation failed: {e}");
929 }
930 Err(StatusCode::UNAUTHORIZED)
931 }
932 }
933}
934
935pub fn build_router(state: Arc<ProxyState>) -> Router {
939 Router::new()
940 .route("/call", post(handle_call))
941 .route("/help", post(handle_help))
942 .route("/mcp", post(handle_mcp))
943 .route("/skills", get(handle_skills_list))
944 .route("/skills/resolve", post(handle_skills_resolve))
945 .route("/skills/{name}", get(handle_skill_detail))
946 .route("/health", get(handle_health))
947 .route("/.well-known/jwks.json", get(handle_jwks))
948 .layer(middleware::from_fn_with_state(
949 state.clone(),
950 auth_middleware,
951 ))
952 .with_state(state)
953}
954
955pub async fn run(
959 port: u16,
960 bind_addr: Option<String>,
961 ati_dir: PathBuf,
962 verbose: bool,
963 env_keys: bool,
964) -> Result<(), Box<dyn std::error::Error>> {
965 let manifests_dir = ati_dir.join("manifests");
967 let registry = ManifestRegistry::load(&manifests_dir)?;
968
969 let tool_count = registry.list_public_tools().len();
970 let provider_count = registry.list_providers().len();
971
972 let keyring_source;
974 let keyring = if env_keys {
975 let kr = Keyring::from_env();
977 let key_names = kr.key_names();
978 eprintln!(
979 " Loaded {} API keys from ATI_KEY_* env vars:",
980 key_names.len()
981 );
982 for name in &key_names {
983 eprintln!(" - {name}");
984 }
985 keyring_source = "env-vars (ATI_KEY_*)";
986 kr
987 } else {
988 let keyring_path = ati_dir.join("keyring.enc");
990 if keyring_path.exists() {
991 if let Ok(kr) = Keyring::load(&keyring_path) {
992 keyring_source = "keyring.enc (sealed key)";
993 kr
994 } else if let Ok(kr) = Keyring::load_local(&keyring_path, &ati_dir) {
995 keyring_source = "keyring.enc (persistent key)";
996 kr
997 } else {
998 eprintln!("Warning: keyring.enc exists but could not be decrypted");
999 keyring_source = "empty (decryption failed)";
1000 Keyring::empty()
1001 }
1002 } else {
1003 let creds_path = ati_dir.join("credentials");
1004 if creds_path.exists() {
1005 match Keyring::load_credentials(&creds_path) {
1006 Ok(kr) => {
1007 keyring_source = "credentials (plaintext)";
1008 kr
1009 }
1010 Err(e) => {
1011 eprintln!("Warning: Failed to load credentials: {e}");
1012 keyring_source = "empty (credentials error)";
1013 Keyring::empty()
1014 }
1015 }
1016 } else {
1017 eprintln!(
1018 "Warning: No keyring.enc or credentials found — running without API keys"
1019 );
1020 eprintln!("Tools requiring authentication will fail.");
1021 keyring_source = "empty (no auth)";
1022 Keyring::empty()
1023 }
1024 }
1025 };
1026
1027 let mcp_providers: Vec<(String, String)> = registry
1029 .list_mcp_providers()
1030 .iter()
1031 .map(|p| (p.name.clone(), p.mcp_transport_type().to_string()))
1032 .collect();
1033 let mcp_count = mcp_providers.len();
1034 let openapi_providers: Vec<String> = registry
1035 .list_openapi_providers()
1036 .iter()
1037 .map(|p| p.name.clone())
1038 .collect();
1039 let openapi_count = openapi_providers.len();
1040
1041 let skills_dir = ati_dir.join("skills");
1043 let skill_registry = SkillRegistry::load(&skills_dir).unwrap_or_else(|e| {
1044 if verbose {
1045 eprintln!("Warning: failed to load skills: {e}");
1046 }
1047 SkillRegistry::load(std::path::Path::new("/nonexistent-fallback")).unwrap()
1048 });
1049 let skill_count = skill_registry.skill_count();
1050
1051 let jwt_config = match jwt::config_from_env() {
1053 Ok(config) => config,
1054 Err(e) => {
1055 eprintln!("Warning: JWT config error: {e}");
1056 None
1057 }
1058 };
1059
1060 let auth_status = if jwt_config.is_some() {
1061 "JWT enabled"
1062 } else {
1063 "DISABLED (no JWT keys configured)"
1064 };
1065
1066 let jwks_json = jwt_config.as_ref().and_then(|config| {
1068 config
1069 .public_key_pem
1070 .as_ref()
1071 .and_then(|pem| jwt::public_key_to_jwks(pem, config.algorithm, "ati-proxy-1").ok())
1072 });
1073
1074 let state = Arc::new(ProxyState {
1075 registry,
1076 skill_registry,
1077 keyring,
1078 verbose,
1079 jwt_config,
1080 jwks_json,
1081 auth_cache: AuthCache::new(),
1082 });
1083
1084 let app = build_router(state);
1085
1086 let addr: SocketAddr = if let Some(ref bind) = bind_addr {
1087 format!("{bind}:{port}").parse()?
1088 } else {
1089 SocketAddr::from(([127, 0, 0, 1], port))
1090 };
1091
1092 eprintln!("ATI proxy server v{}", env!("CARGO_PKG_VERSION"));
1093 eprintln!(" Listening on http://{addr}");
1094 eprintln!(" Auth: {auth_status}");
1095 eprintln!(" ATI dir: {}", ati_dir.display());
1096 eprintln!(" Tools: {tool_count}, Providers: {provider_count} ({mcp_count} MCP, {openapi_count} OpenAPI)");
1097 eprintln!(" Skills: {skill_count}");
1098 eprintln!(" Keyring: {keyring_source}");
1099 eprintln!(" Endpoints: /call, /help, /mcp, /skills, /skills/:name, /skills/resolve, /health, /.well-known/jwks.json");
1100 for (name, transport) in &mcp_providers {
1101 eprintln!(" MCP: {name} ({transport})");
1102 }
1103 for name in &openapi_providers {
1104 eprintln!(" OpenAPI: {name}");
1105 }
1106
1107 let listener = tokio::net::TcpListener::bind(addr).await?;
1108 axum::serve(listener, app).await?;
1109
1110 Ok(())
1111}
1112
1113fn write_proxy_audit(
1115 call_req: &CallRequest,
1116 agent_sub: &str,
1117 duration: std::time::Duration,
1118 error: Option<&str>,
1119) {
1120 let entry = crate::core::audit::AuditEntry {
1121 ts: chrono::Utc::now().to_rfc3339(),
1122 tool: call_req.tool_name.clone(),
1123 args: crate::core::audit::sanitize_args(&serde_json::json!(call_req.args)),
1124 status: if error.is_some() {
1125 crate::core::audit::AuditStatus::Error
1126 } else {
1127 crate::core::audit::AuditStatus::Ok
1128 },
1129 duration_ms: duration.as_millis() as u64,
1130 agent_sub: agent_sub.to_string(),
1131 error: error.map(|s| s.to_string()),
1132 exit_code: None,
1133 };
1134 let _ = crate::core::audit::append(&entry);
1135}
1136
/// System prompt template for unscoped `/help` requests.
///
/// `handle_help` substitutes two placeholders before sending it to the LLM:
/// `{tools}` (the output of `build_tool_context`) and `{skills_section}`
/// (the skills listing, or an empty string when no skills are resolved).
const HELP_SYSTEM_PROMPT: &str = r#"You are a helpful assistant for an AI agent that uses external tools via the `ati` CLI.

## Available Tools
{tools}

{skills_section}

Answer the agent's question naturally, like a knowledgeable colleague would. Keep it short but useful:

- Explain which tools to use and why, with `ati run` commands showing realistic parameter values
- If multiple steps are needed, walk through them briefly in order
- Mention important gotchas or parameter choices that matter
- If skills are relevant, suggest `ati skill show <name>` for the full methodology

Keep your answer concise — a few short paragraphs with embedded code blocks. Only recommend tools from the list above."#;
1154
1155fn build_tool_context(
1156 tools: &[(
1157 &crate::core::manifest::Provider,
1158 &crate::core::manifest::Tool,
1159 )],
1160) -> String {
1161 let mut summaries = Vec::new();
1162 for (provider, tool) in tools {
1163 let mut summary = if let Some(cat) = &provider.category {
1164 format!(
1165 "- **{}** (provider: {}, category: {}): {}",
1166 tool.name, provider.name, cat, tool.description
1167 )
1168 } else {
1169 format!(
1170 "- **{}** (provider: {}): {}",
1171 tool.name, provider.name, tool.description
1172 )
1173 };
1174 if !tool.tags.is_empty() {
1175 summary.push_str(&format!("\n Tags: {}", tool.tags.join(", ")));
1176 }
1177 if provider.is_cli() && tool.input_schema.is_none() {
1179 let cmd = provider.cli_command.as_deref().unwrap_or("?");
1180 summary.push_str(&format!(
1181 "\n Usage: `ati run {} -- <args>` (passthrough to `{}`)",
1182 tool.name, cmd
1183 ));
1184 } else if let Some(schema) = &tool.input_schema {
1185 if let Some(props) = schema.get("properties") {
1186 if let Some(obj) = props.as_object() {
1187 let params: Vec<String> = obj
1188 .iter()
1189 .filter(|(_, v)| {
1190 v.get("x-ati-param-location").is_none()
1191 || v.get("description").is_some()
1192 })
1193 .map(|(k, v)| {
1194 let type_str =
1195 v.get("type").and_then(|t| t.as_str()).unwrap_or("string");
1196 let desc = v.get("description").and_then(|d| d.as_str()).unwrap_or("");
1197 format!(" --{k} ({type_str}): {desc}")
1198 })
1199 .collect();
1200 if !params.is_empty() {
1201 summary.push_str("\n Parameters:\n");
1202 summary.push_str(¶ms.join("\n"));
1203 }
1204 }
1205 }
1206 }
1207 summaries.push(summary);
1208 }
1209 summaries.join("\n\n")
1210}
1211
1212fn build_scoped_prompt(
1216 scope_name: &str,
1217 registry: &ManifestRegistry,
1218 skills_section: &str,
1219) -> Option<String> {
1220 if let Some((provider, tool)) = registry.get_tool(scope_name) {
1222 let mut details = format!(
1223 "**Name**: `{}`\n**Provider**: {} (handler: {})\n**Description**: {}\n",
1224 tool.name, provider.name, provider.handler, tool.description
1225 );
1226 if let Some(cat) = &provider.category {
1227 details.push_str(&format!("**Category**: {}\n", cat));
1228 }
1229 if provider.is_cli() {
1230 let cmd = provider.cli_command.as_deref().unwrap_or("?");
1231 details.push_str(&format!(
1232 "\n**Usage**: `ati run {} -- <args>` (passthrough to `{}`)\n",
1233 tool.name, cmd
1234 ));
1235 } else if let Some(schema) = &tool.input_schema {
1236 if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
1237 let required: Vec<String> = schema
1238 .get("required")
1239 .and_then(|r| r.as_array())
1240 .map(|arr| {
1241 arr.iter()
1242 .filter_map(|v| v.as_str().map(|s| s.to_string()))
1243 .collect()
1244 })
1245 .unwrap_or_default();
1246 details.push_str("\n**Parameters**:\n");
1247 for (key, val) in props {
1248 let type_str = val.get("type").and_then(|t| t.as_str()).unwrap_or("string");
1249 let desc = val
1250 .get("description")
1251 .and_then(|d| d.as_str())
1252 .unwrap_or("");
1253 let req = if required.contains(key) {
1254 " **(required)**"
1255 } else {
1256 ""
1257 };
1258 details.push_str(&format!("- `--{key}` ({type_str}{req}): {desc}\n"));
1259 }
1260 }
1261 }
1262
1263 let prompt = format!(
1264 "You are an expert assistant for the `{}` tool, accessed via the `ati` CLI.\n\n\
1265 ## Tool Details\n{}\n\n{}\n\n\
1266 Answer the agent's question about this specific tool. Provide exact commands, explain flags and options, and give practical examples. Be concise and actionable.",
1267 tool.name, details, skills_section
1268 );
1269 return Some(prompt);
1270 }
1271
1272 if registry.has_provider(scope_name) {
1274 let tools = registry.tools_by_provider(scope_name);
1275 if tools.is_empty() {
1276 return None;
1277 }
1278 let tools_context = build_tool_context(&tools);
1279 let prompt = format!(
1280 "You are an expert assistant for the `{}` provider's tools, accessed via the `ati` CLI.\n\n\
1281 ## Tools in provider `{}`\n{}\n\n{}\n\n\
1282 Answer the agent's question about these tools. Provide exact `ati run` commands, explain parameters, and give practical examples. Be concise and actionable.",
1283 scope_name, scope_name, tools_context, skills_section
1284 );
1285 return Some(prompt);
1286 }
1287
1288 None
1289}