1use axum::{
8 body::Body,
9 extract::State,
10 http::{Request as HttpRequest, StatusCode},
11 middleware::{self, Next},
12 response::{IntoResponse, Response},
13 routing::{get, post},
14 Json, Router,
15};
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use std::collections::HashMap;
19use std::net::SocketAddr;
20use std::path::PathBuf;
21use std::sync::Arc;
22
23use crate::core::auth_generator::{AuthCache, GenContext};
24use crate::core::http;
25use crate::core::jwt::{self, JwtConfig, TokenClaims};
26use crate::core::keyring::Keyring;
27use crate::core::manifest::ManifestRegistry;
28use crate::core::mcp_client;
29use crate::core::response;
30use crate::core::scope::ScopeConfig;
31use crate::core::skill::{self, SkillRegistry};
32use crate::core::xai;
33
34pub struct ProxyState {
36 pub registry: ManifestRegistry,
37 pub skill_registry: SkillRegistry,
38 pub keyring: Keyring,
39 pub verbose: bool,
40 pub jwt_config: Option<JwtConfig>,
42 pub jwks_json: Option<Value>,
44 pub auth_cache: AuthCache,
46}
47
48#[derive(Debug, Deserialize)]
51pub struct CallRequest {
52 pub tool_name: String,
53 #[serde(default = "default_args")]
57 pub args: Value,
58 #[serde(default)]
61 pub raw_args: Option<Vec<String>>,
62}
63
64fn default_args() -> Value {
65 Value::Object(serde_json::Map::new())
66}
67
68impl CallRequest {
69 fn args_as_map(&self) -> HashMap<String, Value> {
73 match &self.args {
74 Value::Object(map) => map.iter().map(|(k, v)| (k.clone(), v.clone())).collect(),
75 _ => HashMap::new(),
76 }
77 }
78
79 fn args_as_positional(&self) -> Vec<String> {
82 if let Some(ref raw) = self.raw_args {
84 return raw.clone();
85 }
86 match &self.args {
87 Value::Array(arr) => arr
89 .iter()
90 .filter_map(|v| match v {
91 Value::String(s) => Some(s.clone()),
92 other => Some(other.to_string()),
93 })
94 .collect(),
95 Value::String(s) => s.split_whitespace().map(String::from).collect(),
97 Value::Object(map) => {
99 if let Some(Value::Array(pos)) = map.get("_positional") {
100 return pos
101 .iter()
102 .filter_map(|v| match v {
103 Value::String(s) => Some(s.clone()),
104 other => Some(other.to_string()),
105 })
106 .collect();
107 }
108 let mut result = Vec::new();
110 for (k, v) in map {
111 result.push(format!("--{k}"));
112 match v {
113 Value::String(s) => result.push(s.clone()),
114 Value::Bool(true) => {} other => result.push(other.to_string()),
116 }
117 }
118 result
119 }
120 _ => Vec::new(),
121 }
122 }
123}
124
125#[derive(Debug, Serialize)]
126pub struct CallResponse {
127 pub result: Value,
128 #[serde(skip_serializing_if = "Option::is_none")]
129 pub error: Option<String>,
130}
131
132#[derive(Debug, Deserialize)]
133pub struct HelpRequest {
134 pub query: String,
135 #[serde(default)]
136 pub tool: Option<String>,
137}
138
139#[derive(Debug, Serialize)]
140pub struct HelpResponse {
141 pub content: String,
142 #[serde(skip_serializing_if = "Option::is_none")]
143 pub error: Option<String>,
144}
145
146#[derive(Debug, Serialize)]
147pub struct HealthResponse {
148 pub status: String,
149 pub version: String,
150 pub tools: usize,
151 pub providers: usize,
152 pub skills: usize,
153 pub auth: String,
154}
155
156#[derive(Debug, Deserialize)]
159pub struct SkillsQuery {
160 #[serde(default)]
161 pub category: Option<String>,
162 #[serde(default)]
163 pub provider: Option<String>,
164 #[serde(default)]
165 pub tool: Option<String>,
166 #[serde(default)]
167 pub search: Option<String>,
168}
169
170#[derive(Debug, Deserialize)]
171pub struct SkillDetailQuery {
172 #[serde(default)]
173 pub meta: Option<bool>,
174 #[serde(default)]
175 pub refs: Option<bool>,
176}
177
178#[derive(Debug, Deserialize)]
179pub struct SkillResolveRequest {
180 pub scopes: Vec<String>,
181}
182
183async fn handle_call(
186 State(state): State<Arc<ProxyState>>,
187 req: HttpRequest<Body>,
188) -> impl IntoResponse {
189 let claims = req.extensions().get::<TokenClaims>().cloned();
191
192 let body_bytes = match axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024).await {
194 Ok(b) => b,
195 Err(e) => {
196 return (
197 StatusCode::BAD_REQUEST,
198 Json(CallResponse {
199 result: Value::Null,
200 error: Some(format!("Failed to read request body: {e}")),
201 }),
202 );
203 }
204 };
205
206 let call_req: CallRequest = match serde_json::from_slice(&body_bytes) {
207 Ok(r) => r,
208 Err(e) => {
209 return (
210 StatusCode::UNPROCESSABLE_ENTITY,
211 Json(CallResponse {
212 result: Value::Null,
213 error: Some(format!("Invalid request: {e}")),
214 }),
215 );
216 }
217 };
218
219 if state.verbose {
220 eprintln!(
221 "[proxy] /call tool={} args={:?}",
222 call_req.tool_name, call_req.args
223 );
224 }
225
226 let (provider, tool) = match state.registry.get_tool(&call_req.tool_name) {
228 Some(pt) => pt,
229 None => {
230 return (
231 StatusCode::NOT_FOUND,
232 Json(CallResponse {
233 result: Value::Null,
234 error: Some(format!("Unknown tool: '{}'", call_req.tool_name)),
235 }),
236 );
237 }
238 };
239
240 if let Some(tool_scope) = &tool.scope {
242 let scopes = match &claims {
243 Some(c) => ScopeConfig::from_jwt(c),
244 None if state.jwt_config.is_none() => ScopeConfig::unrestricted(), None => {
246 return (
247 StatusCode::FORBIDDEN,
248 Json(CallResponse {
249 result: Value::Null,
250 error: Some("Authentication required — no JWT provided".into()),
251 }),
252 );
253 }
254 };
255
256 if let Err(e) = scopes.check_access(&call_req.tool_name, tool_scope) {
257 return (
258 StatusCode::FORBIDDEN,
259 Json(CallResponse {
260 result: Value::Null,
261 error: Some(format!("Access denied: {e}")),
262 }),
263 );
264 }
265 }
266
267 {
269 let scopes = match &claims {
270 Some(c) => ScopeConfig::from_jwt(c),
271 None => ScopeConfig::unrestricted(),
272 };
273 if let Some(ref rate_config) = scopes.rate_config {
274 if let Err(e) = crate::core::rate::check_and_record(&call_req.tool_name, rate_config) {
275 return (
276 StatusCode::TOO_MANY_REQUESTS,
277 Json(CallResponse {
278 result: Value::Null,
279 error: Some(format!("{e}")),
280 }),
281 );
282 }
283 }
284 }
285
286 let gen_ctx = GenContext {
288 jwt_sub: claims
289 .as_ref()
290 .map(|c| c.sub.clone())
291 .unwrap_or_else(|| "dev".into()),
292 jwt_scope: claims
293 .as_ref()
294 .map(|c| c.scope.clone())
295 .unwrap_or_else(|| "*".into()),
296 tool_name: call_req.tool_name.clone(),
297 timestamp: crate::core::jwt::now_secs(),
298 };
299
300 let agent_sub = claims.as_ref().map(|c| c.sub.clone()).unwrap_or_default();
302 let start = std::time::Instant::now();
303
304 let response = match provider.handler.as_str() {
305 "mcp" => {
306 let args_map = call_req.args_as_map();
307 match mcp_client::execute_with_gen(
308 provider,
309 &call_req.tool_name,
310 &args_map,
311 &state.keyring,
312 Some(&gen_ctx),
313 Some(&state.auth_cache),
314 )
315 .await
316 {
317 Ok(result) => (
318 StatusCode::OK,
319 Json(CallResponse {
320 result,
321 error: None,
322 }),
323 ),
324 Err(e) => (
325 StatusCode::BAD_GATEWAY,
326 Json(CallResponse {
327 result: Value::Null,
328 error: Some(format!("MCP error: {e}")),
329 }),
330 ),
331 }
332 }
333 "cli" => {
334 let positional = call_req.args_as_positional();
335 match crate::core::cli_executor::execute_with_gen(
336 provider,
337 &positional,
338 &state.keyring,
339 Some(&gen_ctx),
340 Some(&state.auth_cache),
341 )
342 .await
343 {
344 Ok(result) => (
345 StatusCode::OK,
346 Json(CallResponse {
347 result,
348 error: None,
349 }),
350 ),
351 Err(e) => (
352 StatusCode::BAD_GATEWAY,
353 Json(CallResponse {
354 result: Value::Null,
355 error: Some(format!("CLI error: {e}")),
356 }),
357 ),
358 }
359 }
360 _ => {
361 let args_map = call_req.args_as_map();
362 let raw_response = match match provider.handler.as_str() {
363 "xai" => xai::execute_xai_tool(provider, tool, &args_map, &state.keyring).await,
364 _ => {
365 http::execute_tool_with_gen(
366 provider,
367 tool,
368 &args_map,
369 &state.keyring,
370 Some(&gen_ctx),
371 Some(&state.auth_cache),
372 )
373 .await
374 }
375 } {
376 Ok(resp) => resp,
377 Err(e) => {
378 let duration = start.elapsed();
379 write_proxy_audit(&call_req, &agent_sub, duration, Some(&e.to_string()));
380 return (
381 StatusCode::BAD_GATEWAY,
382 Json(CallResponse {
383 result: Value::Null,
384 error: Some(format!("Upstream API error: {e}")),
385 }),
386 );
387 }
388 };
389
390 let processed = match response::process_response(&raw_response, tool.response.as_ref())
391 {
392 Ok(p) => p,
393 Err(e) => {
394 let duration = start.elapsed();
395 write_proxy_audit(&call_req, &agent_sub, duration, Some(&e.to_string()));
396 return (
397 StatusCode::INTERNAL_SERVER_ERROR,
398 Json(CallResponse {
399 result: raw_response,
400 error: Some(format!("Response processing error: {e}")),
401 }),
402 );
403 }
404 };
405
406 (
407 StatusCode::OK,
408 Json(CallResponse {
409 result: processed,
410 error: None,
411 }),
412 )
413 }
414 };
415
416 let duration = start.elapsed();
417 let error_msg = response.1.error.as_deref();
418 write_proxy_audit(&call_req, &agent_sub, duration, error_msg);
419
420 response
421}
422
423async fn handle_help(
424 State(state): State<Arc<ProxyState>>,
425 Json(req): Json<HelpRequest>,
426) -> impl IntoResponse {
427 if state.verbose {
428 eprintln!("[proxy] /help query={} tool={:?}", req.query, req.tool);
429 }
430
431 let (llm_provider, llm_tool) = match state.registry.get_tool("_chat_completion") {
432 Some(pt) => pt,
433 None => {
434 return (
435 StatusCode::SERVICE_UNAVAILABLE,
436 Json(HelpResponse {
437 content: String::new(),
438 error: Some("No _llm.toml manifest found. Proxy help requires a configured LLM provider.".into()),
439 }),
440 );
441 }
442 };
443
444 let api_key = match llm_provider
445 .auth_key_name
446 .as_deref()
447 .and_then(|k| state.keyring.get(k))
448 {
449 Some(key) => key.to_string(),
450 None => {
451 return (
452 StatusCode::SERVICE_UNAVAILABLE,
453 Json(HelpResponse {
454 content: String::new(),
455 error: Some("LLM API key not found in keyring".into()),
456 }),
457 );
458 }
459 };
460
461 let scopes = ScopeConfig::unrestricted();
462 let resolved_skills = skill::resolve_skills(&state.skill_registry, &state.registry, &scopes);
463 let skills_section = if resolved_skills.is_empty() {
464 String::new()
465 } else {
466 format!(
467 "## Available Skills (methodology guides)\n{}",
468 skill::build_skill_context(&resolved_skills)
469 )
470 };
471
472 let system_prompt = if let Some(ref tool_name) = req.tool {
474 match build_scoped_prompt(tool_name, &state.registry, &skills_section) {
476 Some(prompt) => prompt,
477 None => {
478 if state.verbose {
480 eprintln!(
481 "[proxy] /help scope '{}' not found, falling back to unscoped",
482 tool_name
483 );
484 }
485 let all_tools = state.registry.list_public_tools();
486 let tools_context = build_tool_context(&all_tools);
487 HELP_SYSTEM_PROMPT
488 .replace("{tools}", &tools_context)
489 .replace("{skills_section}", &skills_section)
490 }
491 }
492 } else {
493 let all_tools = state.registry.list_public_tools();
494 let tools_context = build_tool_context(&all_tools);
495 HELP_SYSTEM_PROMPT
496 .replace("{tools}", &tools_context)
497 .replace("{skills_section}", &skills_section)
498 };
499
500 let request_body = serde_json::json!({
501 "model": "zai-glm-4.7",
502 "messages": [
503 {"role": "system", "content": system_prompt},
504 {"role": "user", "content": req.query}
505 ],
506 "max_completion_tokens": 1536,
507 "temperature": 0.3
508 });
509
510 let client = reqwest::Client::new();
511 let url = format!(
512 "{}{}",
513 llm_provider.base_url.trim_end_matches('/'),
514 llm_tool.endpoint
515 );
516
517 let response = match client
518 .post(&url)
519 .bearer_auth(&api_key)
520 .json(&request_body)
521 .send()
522 .await
523 {
524 Ok(r) => r,
525 Err(e) => {
526 return (
527 StatusCode::BAD_GATEWAY,
528 Json(HelpResponse {
529 content: String::new(),
530 error: Some(format!("LLM request failed: {e}")),
531 }),
532 );
533 }
534 };
535
536 if !response.status().is_success() {
537 let status = response.status();
538 let body = response.text().await.unwrap_or_default();
539 return (
540 StatusCode::BAD_GATEWAY,
541 Json(HelpResponse {
542 content: String::new(),
543 error: Some(format!("LLM API error ({status}): {body}")),
544 }),
545 );
546 }
547
548 let body: Value = match response.json().await {
549 Ok(b) => b,
550 Err(e) => {
551 return (
552 StatusCode::INTERNAL_SERVER_ERROR,
553 Json(HelpResponse {
554 content: String::new(),
555 error: Some(format!("Failed to parse LLM response: {e}")),
556 }),
557 );
558 }
559 };
560
561 let content = body
562 .pointer("/choices/0/message/content")
563 .and_then(|c| c.as_str())
564 .unwrap_or("No response from LLM")
565 .to_string();
566
567 (
568 StatusCode::OK,
569 Json(HelpResponse {
570 content,
571 error: None,
572 }),
573 )
574}
575
576async fn handle_health(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
577 let auth = if state.jwt_config.is_some() {
578 "jwt"
579 } else {
580 "disabled"
581 };
582
583 Json(HealthResponse {
584 status: "ok".into(),
585 version: env!("CARGO_PKG_VERSION").into(),
586 tools: state.registry.list_public_tools().len(),
587 providers: state.registry.list_providers().len(),
588 skills: state.skill_registry.skill_count(),
589 auth: auth.into(),
590 })
591}
592
593async fn handle_jwks(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
595 match &state.jwks_json {
596 Some(jwks) => (StatusCode::OK, Json(jwks.clone())),
597 None => (
598 StatusCode::NOT_FOUND,
599 Json(serde_json::json!({"error": "JWKS not configured"})),
600 ),
601 }
602}
603
604async fn handle_mcp(
609 State(state): State<Arc<ProxyState>>,
610 Json(msg): Json<Value>,
611) -> impl IntoResponse {
612 let method = msg.get("method").and_then(|m| m.as_str()).unwrap_or("");
613 let id = msg.get("id").cloned();
614
615 if state.verbose {
616 eprintln!("[proxy] /mcp method={method}");
617 }
618
619 match method {
620 "initialize" => {
621 let result = serde_json::json!({
622 "protocolVersion": "2025-03-26",
623 "capabilities": {
624 "tools": { "listChanged": false }
625 },
626 "serverInfo": {
627 "name": "ati-proxy",
628 "version": env!("CARGO_PKG_VERSION")
629 }
630 });
631 jsonrpc_success(id, result)
632 }
633
634 "notifications/initialized" => (StatusCode::ACCEPTED, Json(Value::Null)),
635
636 "tools/list" => {
637 let all_tools = state.registry.list_public_tools();
638 let mcp_tools: Vec<Value> = all_tools
639 .iter()
640 .map(|(_provider, tool)| {
641 serde_json::json!({
642 "name": tool.name,
643 "description": tool.description,
644 "inputSchema": tool.input_schema.clone().unwrap_or(serde_json::json!({
645 "type": "object",
646 "properties": {}
647 }))
648 })
649 })
650 .collect();
651
652 let result = serde_json::json!({
653 "tools": mcp_tools,
654 });
655 jsonrpc_success(id, result)
656 }
657
658 "tools/call" => {
659 let params = msg.get("params").cloned().unwrap_or(Value::Null);
660 let tool_name = params.get("name").and_then(|n| n.as_str()).unwrap_or("");
661 let arguments: HashMap<String, Value> = params
662 .get("arguments")
663 .and_then(|a| serde_json::from_value(a.clone()).ok())
664 .unwrap_or_default();
665
666 if tool_name.is_empty() {
667 return jsonrpc_error(id, -32602, "Missing tool name in params.name");
668 }
669
670 let (provider, _tool) = match state.registry.get_tool(tool_name) {
671 Some(pt) => pt,
672 None => {
673 return jsonrpc_error(id, -32602, &format!("Unknown tool: '{tool_name}'"));
674 }
675 };
676
677 if state.verbose {
678 eprintln!(
679 "[proxy] /mcp tools/call name={tool_name} provider={}",
680 provider.name
681 );
682 }
683
684 let mcp_gen_ctx = GenContext {
685 jwt_sub: "dev".into(),
686 jwt_scope: "*".into(),
687 tool_name: tool_name.to_string(),
688 timestamp: crate::core::jwt::now_secs(),
689 };
690
691 let result = if provider.is_mcp() {
692 mcp_client::execute_with_gen(
693 provider,
694 tool_name,
695 &arguments,
696 &state.keyring,
697 Some(&mcp_gen_ctx),
698 Some(&state.auth_cache),
699 )
700 .await
701 } else if provider.is_cli() {
702 let raw: Vec<String> = arguments
704 .iter()
705 .flat_map(|(k, v)| {
706 let val = match v {
707 Value::String(s) => s.clone(),
708 other => other.to_string(),
709 };
710 vec![format!("--{k}"), val]
711 })
712 .collect();
713 crate::core::cli_executor::execute_with_gen(
714 provider,
715 &raw,
716 &state.keyring,
717 Some(&mcp_gen_ctx),
718 Some(&state.auth_cache),
719 )
720 .await
721 .map_err(|e| mcp_client::McpError::Transport(e.to_string()))
722 } else {
723 match match provider.handler.as_str() {
724 "xai" => {
725 xai::execute_xai_tool(provider, _tool, &arguments, &state.keyring).await
726 }
727 _ => {
728 http::execute_tool_with_gen(
729 provider,
730 _tool,
731 &arguments,
732 &state.keyring,
733 Some(&mcp_gen_ctx),
734 Some(&state.auth_cache),
735 )
736 .await
737 }
738 } {
739 Ok(val) => Ok(val),
740 Err(e) => Err(mcp_client::McpError::Transport(e.to_string())),
741 }
742 };
743
744 match result {
745 Ok(value) => {
746 let text = match &value {
747 Value::String(s) => s.clone(),
748 other => serde_json::to_string_pretty(other).unwrap_or_default(),
749 };
750 let mcp_result = serde_json::json!({
751 "content": [{"type": "text", "text": text}],
752 "isError": false,
753 });
754 jsonrpc_success(id, mcp_result)
755 }
756 Err(e) => {
757 let mcp_result = serde_json::json!({
758 "content": [{"type": "text", "text": format!("Error: {e}")}],
759 "isError": true,
760 });
761 jsonrpc_success(id, mcp_result)
762 }
763 }
764 }
765
766 _ => jsonrpc_error(id, -32601, &format!("Method not found: '{method}'")),
767 }
768}
769
770fn jsonrpc_success(id: Option<Value>, result: Value) -> (StatusCode, Json<Value>) {
771 (
772 StatusCode::OK,
773 Json(serde_json::json!({
774 "jsonrpc": "2.0",
775 "id": id,
776 "result": result,
777 })),
778 )
779}
780
781fn jsonrpc_error(id: Option<Value>, code: i64, message: &str) -> (StatusCode, Json<Value>) {
782 (
783 StatusCode::OK,
784 Json(serde_json::json!({
785 "jsonrpc": "2.0",
786 "id": id,
787 "error": {
788 "code": code,
789 "message": message,
790 }
791 })),
792 )
793}
794
795async fn handle_skills_list(
800 State(state): State<Arc<ProxyState>>,
801 axum::extract::Query(query): axum::extract::Query<SkillsQuery>,
802) -> impl IntoResponse {
803 if state.verbose {
804 eprintln!(
805 "[proxy] GET /skills category={:?} provider={:?} tool={:?} search={:?}",
806 query.category, query.provider, query.tool, query.search
807 );
808 }
809
810 let skills: Vec<&skill::SkillMeta> = if let Some(search_query) = &query.search {
811 state.skill_registry.search(search_query)
812 } else if let Some(cat) = &query.category {
813 state.skill_registry.skills_for_category(cat)
814 } else if let Some(prov) = &query.provider {
815 state.skill_registry.skills_for_provider(prov)
816 } else if let Some(t) = &query.tool {
817 state.skill_registry.skills_for_tool(t)
818 } else {
819 state.skill_registry.list_skills().iter().collect()
820 };
821
822 let json: Vec<Value> = skills
823 .iter()
824 .map(|s| {
825 serde_json::json!({
826 "name": s.name,
827 "version": s.version,
828 "description": s.description,
829 "tools": s.tools,
830 "providers": s.providers,
831 "categories": s.categories,
832 "hint": s.hint,
833 })
834 })
835 .collect();
836
837 (StatusCode::OK, Json(Value::Array(json)))
838}
839
840async fn handle_skill_detail(
841 State(state): State<Arc<ProxyState>>,
842 axum::extract::Path(name): axum::extract::Path<String>,
843 axum::extract::Query(query): axum::extract::Query<SkillDetailQuery>,
844) -> impl IntoResponse {
845 if state.verbose {
846 eprintln!(
847 "[proxy] GET /skills/{name} meta={:?} refs={:?}",
848 query.meta, query.refs
849 );
850 }
851
852 let skill_meta = match state.skill_registry.get_skill(&name) {
853 Some(s) => s,
854 None => {
855 return (
856 StatusCode::NOT_FOUND,
857 Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
858 );
859 }
860 };
861
862 if query.meta.unwrap_or(false) {
863 return (
864 StatusCode::OK,
865 Json(serde_json::json!({
866 "name": skill_meta.name,
867 "version": skill_meta.version,
868 "description": skill_meta.description,
869 "author": skill_meta.author,
870 "tools": skill_meta.tools,
871 "providers": skill_meta.providers,
872 "categories": skill_meta.categories,
873 "keywords": skill_meta.keywords,
874 "hint": skill_meta.hint,
875 "depends_on": skill_meta.depends_on,
876 "suggests": skill_meta.suggests,
877 "license": skill_meta.license,
878 "compatibility": skill_meta.compatibility,
879 "allowed_tools": skill_meta.allowed_tools,
880 "format": skill_meta.format,
881 })),
882 );
883 }
884
885 let content = match state.skill_registry.read_content(&name) {
886 Ok(c) => c,
887 Err(e) => {
888 return (
889 StatusCode::INTERNAL_SERVER_ERROR,
890 Json(serde_json::json!({"error": format!("Failed to read skill: {e}")})),
891 );
892 }
893 };
894
895 let mut response = serde_json::json!({
896 "name": skill_meta.name,
897 "version": skill_meta.version,
898 "description": skill_meta.description,
899 "content": content,
900 });
901
902 if query.refs.unwrap_or(false) {
903 if let Ok(refs) = state.skill_registry.list_references(&name) {
904 response["references"] = serde_json::json!(refs);
905 }
906 }
907
908 (StatusCode::OK, Json(response))
909}
910
911async fn handle_skills_resolve(
912 State(state): State<Arc<ProxyState>>,
913 Json(req): Json<SkillResolveRequest>,
914) -> impl IntoResponse {
915 if state.verbose {
916 eprintln!("[proxy] POST /skills/resolve scopes={:?}", req.scopes);
917 }
918
919 let scopes = ScopeConfig {
920 scopes: req.scopes,
921 sub: String::new(),
922 expires_at: 0,
923 rate_config: None,
924 };
925
926 let resolved = skill::resolve_skills(&state.skill_registry, &state.registry, &scopes);
927
928 let json: Vec<Value> = resolved
929 .iter()
930 .map(|s| {
931 serde_json::json!({
932 "name": s.name,
933 "version": s.version,
934 "description": s.description,
935 "tools": s.tools,
936 "providers": s.providers,
937 "categories": s.categories,
938 })
939 })
940 .collect();
941
942 (StatusCode::OK, Json(Value::Array(json)))
943}
944
945async fn auth_middleware(
953 State(state): State<Arc<ProxyState>>,
954 mut req: HttpRequest<Body>,
955 next: Next,
956) -> Result<Response, StatusCode> {
957 let path = req.uri().path();
958
959 if path == "/health" || path == "/.well-known/jwks.json" {
961 return Ok(next.run(req).await);
962 }
963
964 let jwt_config = match &state.jwt_config {
966 Some(c) => c,
967 None => return Ok(next.run(req).await),
968 };
969
970 let auth_header = req
972 .headers()
973 .get("authorization")
974 .and_then(|v| v.to_str().ok());
975
976 let token = match auth_header {
977 Some(header) if header.starts_with("Bearer ") => &header[7..],
978 _ => return Err(StatusCode::UNAUTHORIZED),
979 };
980
981 match jwt::validate(token, jwt_config) {
983 Ok(claims) => {
984 if state.verbose {
985 eprintln!(
986 "[proxy] JWT validated: sub={} scopes={}",
987 claims.sub, claims.scope
988 );
989 }
990 req.extensions_mut().insert(claims);
991 Ok(next.run(req).await)
992 }
993 Err(e) => {
994 if state.verbose {
995 eprintln!("[proxy] JWT validation failed: {e}");
996 }
997 Err(StatusCode::UNAUTHORIZED)
998 }
999 }
1000}
1001
1002pub fn build_router(state: Arc<ProxyState>) -> Router {
1006 Router::new()
1007 .route("/call", post(handle_call))
1008 .route("/help", post(handle_help))
1009 .route("/mcp", post(handle_mcp))
1010 .route("/skills", get(handle_skills_list))
1011 .route("/skills/resolve", post(handle_skills_resolve))
1012 .route("/skills/{name}", get(handle_skill_detail))
1013 .route("/health", get(handle_health))
1014 .route("/.well-known/jwks.json", get(handle_jwks))
1015 .layer(middleware::from_fn_with_state(
1016 state.clone(),
1017 auth_middleware,
1018 ))
1019 .with_state(state)
1020}
1021
1022pub async fn run(
1026 port: u16,
1027 bind_addr: Option<String>,
1028 ati_dir: PathBuf,
1029 verbose: bool,
1030 env_keys: bool,
1031) -> Result<(), Box<dyn std::error::Error>> {
1032 let manifests_dir = ati_dir.join("manifests");
1034 let registry = ManifestRegistry::load(&manifests_dir)?;
1035
1036 let tool_count = registry.list_public_tools().len();
1037 let provider_count = registry.list_providers().len();
1038
1039 let keyring_source;
1041 let keyring = if env_keys {
1042 let kr = Keyring::from_env();
1044 let key_names = kr.key_names();
1045 eprintln!(
1046 " Loaded {} API keys from ATI_KEY_* env vars:",
1047 key_names.len()
1048 );
1049 for name in &key_names {
1050 eprintln!(" - {name}");
1051 }
1052 keyring_source = "env-vars (ATI_KEY_*)";
1053 kr
1054 } else {
1055 let keyring_path = ati_dir.join("keyring.enc");
1057 if keyring_path.exists() {
1058 if let Ok(kr) = Keyring::load(&keyring_path) {
1059 keyring_source = "keyring.enc (sealed key)";
1060 kr
1061 } else if let Ok(kr) = Keyring::load_local(&keyring_path, &ati_dir) {
1062 keyring_source = "keyring.enc (persistent key)";
1063 kr
1064 } else {
1065 eprintln!("Warning: keyring.enc exists but could not be decrypted");
1066 keyring_source = "empty (decryption failed)";
1067 Keyring::empty()
1068 }
1069 } else {
1070 let creds_path = ati_dir.join("credentials");
1071 if creds_path.exists() {
1072 match Keyring::load_credentials(&creds_path) {
1073 Ok(kr) => {
1074 keyring_source = "credentials (plaintext)";
1075 kr
1076 }
1077 Err(e) => {
1078 eprintln!("Warning: Failed to load credentials: {e}");
1079 keyring_source = "empty (credentials error)";
1080 Keyring::empty()
1081 }
1082 }
1083 } else {
1084 eprintln!(
1085 "Warning: No keyring.enc or credentials found — running without API keys"
1086 );
1087 eprintln!("Tools requiring authentication will fail.");
1088 keyring_source = "empty (no auth)";
1089 Keyring::empty()
1090 }
1091 }
1092 };
1093
1094 let mcp_providers: Vec<(String, String)> = registry
1096 .list_mcp_providers()
1097 .iter()
1098 .map(|p| (p.name.clone(), p.mcp_transport_type().to_string()))
1099 .collect();
1100 let mcp_count = mcp_providers.len();
1101 let openapi_providers: Vec<String> = registry
1102 .list_openapi_providers()
1103 .iter()
1104 .map(|p| p.name.clone())
1105 .collect();
1106 let openapi_count = openapi_providers.len();
1107
1108 let skills_dir = ati_dir.join("skills");
1110 let skill_registry = SkillRegistry::load(&skills_dir).unwrap_or_else(|e| {
1111 if verbose {
1112 eprintln!("Warning: failed to load skills: {e}");
1113 }
1114 SkillRegistry::load(std::path::Path::new("/nonexistent-fallback")).unwrap()
1115 });
1116 let skill_count = skill_registry.skill_count();
1117
1118 let jwt_config = match jwt::config_from_env() {
1120 Ok(config) => config,
1121 Err(e) => {
1122 eprintln!("Warning: JWT config error: {e}");
1123 None
1124 }
1125 };
1126
1127 let auth_status = if jwt_config.is_some() {
1128 "JWT enabled"
1129 } else {
1130 "DISABLED (no JWT keys configured)"
1131 };
1132
1133 let jwks_json = jwt_config.as_ref().and_then(|config| {
1135 config
1136 .public_key_pem
1137 .as_ref()
1138 .and_then(|pem| jwt::public_key_to_jwks(pem, config.algorithm, "ati-proxy-1").ok())
1139 });
1140
1141 let state = Arc::new(ProxyState {
1142 registry,
1143 skill_registry,
1144 keyring,
1145 verbose,
1146 jwt_config,
1147 jwks_json,
1148 auth_cache: AuthCache::new(),
1149 });
1150
1151 let app = build_router(state);
1152
1153 let addr: SocketAddr = if let Some(ref bind) = bind_addr {
1154 format!("{bind}:{port}").parse()?
1155 } else {
1156 SocketAddr::from(([127, 0, 0, 1], port))
1157 };
1158
1159 eprintln!("ATI proxy server v{}", env!("CARGO_PKG_VERSION"));
1160 eprintln!(" Listening on http://{addr}");
1161 eprintln!(" Auth: {auth_status}");
1162 eprintln!(" ATI dir: {}", ati_dir.display());
1163 eprintln!(" Tools: {tool_count}, Providers: {provider_count} ({mcp_count} MCP, {openapi_count} OpenAPI)");
1164 eprintln!(" Skills: {skill_count}");
1165 eprintln!(" Keyring: {keyring_source}");
1166 eprintln!(" Endpoints: /call, /help, /mcp, /skills, /skills/:name, /skills/resolve, /health, /.well-known/jwks.json");
1167 for (name, transport) in &mcp_providers {
1168 eprintln!(" MCP: {name} ({transport})");
1169 }
1170 for name in &openapi_providers {
1171 eprintln!(" OpenAPI: {name}");
1172 }
1173
1174 let listener = tokio::net::TcpListener::bind(addr).await?;
1175 axum::serve(listener, app).await?;
1176
1177 Ok(())
1178}
1179
1180fn write_proxy_audit(
1182 call_req: &CallRequest,
1183 agent_sub: &str,
1184 duration: std::time::Duration,
1185 error: Option<&str>,
1186) {
1187 let entry = crate::core::audit::AuditEntry {
1188 ts: chrono::Utc::now().to_rfc3339(),
1189 tool: call_req.tool_name.clone(),
1190 args: crate::core::audit::sanitize_args(&call_req.args),
1191 status: if error.is_some() {
1192 crate::core::audit::AuditStatus::Error
1193 } else {
1194 crate::core::audit::AuditStatus::Ok
1195 },
1196 duration_ms: duration.as_millis() as u64,
1197 agent_sub: agent_sub.to_string(),
1198 error: error.map(|s| s.to_string()),
1199 exit_code: None,
1200 };
1201 let _ = crate::core::audit::append(&entry);
1202}
1203
/// System prompt template for unscoped `/help` requests.
///
/// Placeholders:
/// - `{tools}`: replaced in `handle_help` with the tool list produced by
///   `build_tool_context`.
/// - `{skills_section}`: replaced with the skills section, which is empty when no
///   skills resolve.
const HELP_SYSTEM_PROMPT: &str = r#"You are a helpful assistant for an AI agent that uses external tools via the `ati` CLI.

## Available Tools
{tools}

{skills_section}

Answer the agent's question naturally, like a knowledgeable colleague would. Keep it short but useful:

- Explain which tools to use and why, with `ati run` commands showing realistic parameter values
- If multiple steps are needed, walk through them briefly in order
- Mention important gotchas or parameter choices that matter
- If skills are relevant, suggest `ati skill show <name>` for the full methodology

Keep your answer concise — a few short paragraphs with embedded code blocks. Only recommend tools from the list above."#;
1221
1222fn build_tool_context(
1223 tools: &[(
1224 &crate::core::manifest::Provider,
1225 &crate::core::manifest::Tool,
1226 )],
1227) -> String {
1228 let mut summaries = Vec::new();
1229 for (provider, tool) in tools {
1230 let mut summary = if let Some(cat) = &provider.category {
1231 format!(
1232 "- **{}** (provider: {}, category: {}): {}",
1233 tool.name, provider.name, cat, tool.description
1234 )
1235 } else {
1236 format!(
1237 "- **{}** (provider: {}): {}",
1238 tool.name, provider.name, tool.description
1239 )
1240 };
1241 if !tool.tags.is_empty() {
1242 summary.push_str(&format!("\n Tags: {}", tool.tags.join(", ")));
1243 }
1244 if provider.is_cli() && tool.input_schema.is_none() {
1246 let cmd = provider.cli_command.as_deref().unwrap_or("?");
1247 summary.push_str(&format!(
1248 "\n Usage: `ati run {} -- <args>` (passthrough to `{}`)",
1249 tool.name, cmd
1250 ));
1251 } else if let Some(schema) = &tool.input_schema {
1252 if let Some(props) = schema.get("properties") {
1253 if let Some(obj) = props.as_object() {
1254 let params: Vec<String> = obj
1255 .iter()
1256 .filter(|(_, v)| {
1257 v.get("x-ati-param-location").is_none()
1258 || v.get("description").is_some()
1259 })
1260 .map(|(k, v)| {
1261 let type_str =
1262 v.get("type").and_then(|t| t.as_str()).unwrap_or("string");
1263 let desc = v.get("description").and_then(|d| d.as_str()).unwrap_or("");
1264 format!(" --{k} ({type_str}): {desc}")
1265 })
1266 .collect();
1267 if !params.is_empty() {
1268 summary.push_str("\n Parameters:\n");
1269 summary.push_str(¶ms.join("\n"));
1270 }
1271 }
1272 }
1273 }
1274 summaries.push(summary);
1275 }
1276 summaries.join("\n\n")
1277}
1278
1279fn build_scoped_prompt(
1283 scope_name: &str,
1284 registry: &ManifestRegistry,
1285 skills_section: &str,
1286) -> Option<String> {
1287 if let Some((provider, tool)) = registry.get_tool(scope_name) {
1289 let mut details = format!(
1290 "**Name**: `{}`\n**Provider**: {} (handler: {})\n**Description**: {}\n",
1291 tool.name, provider.name, provider.handler, tool.description
1292 );
1293 if let Some(cat) = &provider.category {
1294 details.push_str(&format!("**Category**: {}\n", cat));
1295 }
1296 if provider.is_cli() {
1297 let cmd = provider.cli_command.as_deref().unwrap_or("?");
1298 details.push_str(&format!(
1299 "\n**Usage**: `ati run {} -- <args>` (passthrough to `{}`)\n",
1300 tool.name, cmd
1301 ));
1302 } else if let Some(schema) = &tool.input_schema {
1303 if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
1304 let required: Vec<String> = schema
1305 .get("required")
1306 .and_then(|r| r.as_array())
1307 .map(|arr| {
1308 arr.iter()
1309 .filter_map(|v| v.as_str().map(|s| s.to_string()))
1310 .collect()
1311 })
1312 .unwrap_or_default();
1313 details.push_str("\n**Parameters**:\n");
1314 for (key, val) in props {
1315 let type_str = val.get("type").and_then(|t| t.as_str()).unwrap_or("string");
1316 let desc = val
1317 .get("description")
1318 .and_then(|d| d.as_str())
1319 .unwrap_or("");
1320 let req = if required.contains(key) {
1321 " **(required)**"
1322 } else {
1323 ""
1324 };
1325 details.push_str(&format!("- `--{key}` ({type_str}{req}): {desc}\n"));
1326 }
1327 }
1328 }
1329
1330 let prompt = format!(
1331 "You are an expert assistant for the `{}` tool, accessed via the `ati` CLI.\n\n\
1332 ## Tool Details\n{}\n\n{}\n\n\
1333 Answer the agent's question about this specific tool. Provide exact commands, explain flags and options, and give practical examples. Be concise and actionable.",
1334 tool.name, details, skills_section
1335 );
1336 return Some(prompt);
1337 }
1338
1339 if registry.has_provider(scope_name) {
1341 let tools = registry.tools_by_provider(scope_name);
1342 if tools.is_empty() {
1343 return None;
1344 }
1345 let tools_context = build_tool_context(&tools);
1346 let prompt = format!(
1347 "You are an expert assistant for the `{}` provider's tools, accessed via the `ati` CLI.\n\n\
1348 ## Tools in provider `{}`\n{}\n\n{}\n\n\
1349 Answer the agent's question about these tools. Provide exact `ati run` commands, explain parameters, and give practical examples. Be concise and actionable.",
1350 scope_name, scope_name, tools_context, skills_section
1351 );
1352 return Some(prompt);
1353 }
1354
1355 None
1356}