Skip to main content

atd_cli/
skills.rs

1//! `atd skills sync` — pull skill files from a connected ATD server via
2//! the skills meta-tool convention (`<publisher>:<service>.skills.list/get`)
3//! and write them to per-platform install paths.
4//!
5//! See [`docs/protocol/wire-format.md` §11](../../../docs/protocol/wire-format.md)
6//! for the convention contract and SP-skills-discovery-convention for the
7//! design rationale.
8
9use std::io::Write;
10use std::path::PathBuf;
11
12use atd_protocol::AtdError;
13use atd_sdk::{AtdClient, CallOptions, DiscoverFilter};
14use serde_json::Value;
15
16use crate::cli::{SkillsSyncArgs, SyncTarget};
17
18pub async fn run(
19    client: &AtdClient,
20    args: SkillsSyncArgs,
21    out: &mut impl Write,
22) -> Result<(), AtdError> {
23    let resolved_out_dir = args
24        .out_dir
25        .clone()
26        .or_else(|| args.target.default_out_dir());
27
28    if matches!(args.target, SyncTarget::Stdout) && args.out_dir.is_some() {
29        return Err(AtdError::InvalidArguments {
30            tool_id: "atd:skills.sync".into(),
31            field: "--out-dir".into(),
32            reason: "cannot be combined with --target stdout; pipe instead".into(),
33        });
34    }
35
36    let tools = client.discover(None, DiscoverFilter::default()).await?;
37
38    let list_ids: Vec<String> = tools
39        .iter()
40        .map(|t| t.id.clone())
41        .filter(|id| id.ends_with(".skills.list"))
42        .collect();
43
44    if list_ids.is_empty() {
45        writeln!(
46            out,
47            "no *.skills.list tool found on this server; nothing to sync"
48        )
49        .ok();
50        return Ok(());
51    }
52
53    let mut total_synced = 0usize;
54    let publishers = list_ids.len();
55
56    for list_id in &list_ids {
57        let prefix =
58            list_id
59                .strip_suffix(".skills.list")
60                .ok_or_else(|| AtdError::ProtocolError {
61                    expected: "tool id ending in .skills.list".into(),
62                    got: list_id.clone(),
63                })?;
64        let get_id = format!("{prefix}.skills.get");
65
66        let entries = call_list(client, list_id).await?;
67        let dir_prefix = prefix.replace([':', '.'], "-");
68
69        for entry in &entries {
70            let name = entry.get("name").and_then(Value::as_str).ok_or_else(|| {
71                AtdError::ProtocolError {
72                    expected: "skill summary entry with `name` field".into(),
73                    got: entry.to_string(),
74                }
75            })?;
76
77            let content = call_get(client, &get_id, name).await?;
78            write_skill(
79                args.target,
80                resolved_out_dir.as_ref(),
81                &dir_prefix,
82                name,
83                &content,
84                args.dry_run,
85                out,
86            )?;
87            total_synced += 1;
88        }
89    }
90
91    let dest = resolved_out_dir
92        .as_ref()
93        .map(|p| p.display().to_string())
94        .unwrap_or_else(|| "stdout".into());
95    writeln!(
96        out,
97        "{total_synced} skill(s) synced from {publishers} publisher(s) to {dest}"
98    )
99    .ok();
100    Ok(())
101}
102
103async fn call_list(client: &AtdClient, list_id: &str) -> Result<Vec<Value>, AtdError> {
104    let result = client
105        .call(list_id, serde_json::json!({}), CallOptions::default())
106        .await?;
107    let data = match result {
108        atd_protocol::ToolResult::Success { data, .. } => data,
109        atd_protocol::ToolResult::Error { code, message, .. } => {
110            return Err(AtdError::ToolExecutionFailed {
111                tool_id: list_id.into(),
112                inner: Box::new(std::io::Error::other(format!("[{code}] {message}"))),
113            });
114        }
115    };
116    data.as_array()
117        .cloned()
118        .ok_or_else(|| AtdError::ProtocolError {
119            expected: "Vec<SkillSummary>".into(),
120            got: data.to_string(),
121        })
122}
123
124async fn call_get(client: &AtdClient, get_id: &str, name: &str) -> Result<String, AtdError> {
125    let result = client
126        .call(
127            get_id,
128            serde_json::json!({"name": name}),
129            CallOptions::default(),
130        )
131        .await?;
132    let data = match result {
133        atd_protocol::ToolResult::Success { data, .. } => data,
134        atd_protocol::ToolResult::Error { code, message, .. } => {
135            return Err(AtdError::ToolExecutionFailed {
136                tool_id: get_id.into(),
137                inner: Box::new(std::io::Error::other(format!(
138                    "[{code}] {message} (skill: {name})"
139                ))),
140            });
141        }
142    };
143    data.get("content_md")
144        .and_then(Value::as_str)
145        .map(String::from)
146        .ok_or_else(|| AtdError::ProtocolError {
147            expected: "skills.get response with content_md field".into(),
148            got: data.to_string(),
149        })
150}
151
152fn write_skill(
153    target: SyncTarget,
154    out_dir: Option<&PathBuf>,
155    dir_prefix: &str,
156    name: &str,
157    content: &str,
158    dry_run: bool,
159    out: &mut impl Write,
160) -> Result<(), AtdError> {
161    let safe_name = sanitize_name(name);
162    match target {
163        SyncTarget::Stdout => {
164            writeln!(out, "--- {dir_prefix}-{safe_name} ---").ok();
165            write!(out, "{content}").ok();
166            if !content.ends_with('\n') {
167                writeln!(out).ok();
168            }
169            Ok(())
170        }
171        SyncTarget::Hermes | SyncTarget::ClaudeCode => {
172            let base = out_dir.ok_or_else(|| AtdError::InvalidArguments {
173                tool_id: "atd:skills.sync".into(),
174                field: "--out-dir".into(),
175                reason: "no install dir resolved (HOME unset?); supply --out-dir explicitly".into(),
176            })?;
177            let dir = base.join(format!("{dir_prefix}-{safe_name}"));
178            let path = dir.join("SKILL.md");
179            if dry_run {
180                writeln!(
181                    out,
182                    "[would write] {} ({} bytes)",
183                    path.display(),
184                    content.len()
185                )
186                .ok();
187            } else {
188                std::fs::create_dir_all(&dir).map_err(|e| AtdError::ToolExecutionFailed {
189                    tool_id: "atd:skills.sync".into(),
190                    inner: Box::new(e),
191                })?;
192                std::fs::write(&path, content).map_err(|e| AtdError::ToolExecutionFailed {
193                    tool_id: "atd:skills.sync".into(),
194                    inner: Box::new(e),
195                })?;
196                writeln!(out, "[wrote] {}", path.display()).ok();
197            }
198            Ok(())
199        }
200    }
201}
202
/// Replace every character outside `[A-Za-z0-9._-]` with `_`.
///
/// The output never contains `/` or `\`. Note that `.` is kept, so an input
/// of `..` comes back unchanged. Callers should embed the result in a larger
/// component rather than use it as a path component on its own.
fn sanitize_name(s: &str) -> String {
    let mut cleaned = String::with_capacity(s.len());
    for ch in s.chars() {
        let allowed = ch.is_ascii_alphanumeric() || matches!(ch, '-' | '_' | '.');
        cleaned.push(if allowed { ch } else { '_' });
    }
    cleaned
}
214
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sanitize_name_strips_unsafe_chars() {
        let cases = [
            ("healthkit-heartrate", "healthkit-heartrate"),
            ("a/b\\c d", "a_b_c_d"),
            ("foo.bar_baz-qux", "foo.bar_baz-qux"),
        ];
        for (input, expected) in cases {
            assert_eq!(sanitize_name(input), expected, "input: {input:?}");
        }
    }

    #[test]
    fn dir_prefix_replaces_colon_and_dot() {
        let normalized = "huawei:hms.healthkit".replace([':', '.'], "-");
        assert_eq!(normalized, "huawei-hms-healthkit");
    }
}