Skip to main content

aft/
harness.rs

1use serde::de::{self, Visitor};
2use serde::{Deserialize, Deserializer, Serialize, Serializer};
3use std::fmt;
4use std::str::FromStr;
5
/// Identifies a harness: one of three built-in kinds, or an MCP client named
/// by a free-form string.
///
/// On the wire a harness is a single string — `opencode`, `pi`, `runner`, or
/// `mcp:<client>` — see [`Harness::wire_label`]. For the name used in storage
/// paths see [`Harness::storage_segment`].
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum Harness {
    /// Built-in harness with the label `opencode`.
    Opencode,
    /// Built-in harness with the label `pi`.
    Pi,
    /// Built-in harness with the label `runner`.
    Runner,
    /// An MCP client.
    Mcp {
        /// Raw client name exactly as supplied; it is NOT sanitized here.
        client: String,
    },
}
13
14impl Harness {
15    pub fn storage_segment(&self) -> String {
16        match self {
17            Harness::Opencode => "opencode".to_string(),
18            Harness::Pi => "pi".to_string(),
19            Harness::Runner => "runner".to_string(),
20            Harness::Mcp { client } => format!("mcp--{}", sanitize_client(client)),
21        }
22    }
23
24    pub fn wire_label(&self) -> String {
25        match self {
26            Harness::Opencode => "opencode".to_string(),
27            Harness::Pi => "pi".to_string(),
28            Harness::Runner => "runner".to_string(),
29            Harness::Mcp { client } => format!("mcp:{client}"),
30        }
31    }
32}
33
/// Upper bound, in bytes, on the human-readable part of an MCP storage slug.
///
/// A full MCP segment looks like `mcp--<readable>--<32 hex>`; capping the
/// readable part keeps directory names bounded, while the hash suffix is what
/// keeps distinct clients apart.
const MCP_SLUG_READABLE_MAX: usize = 40;

/// Number of hex digits of the client hash kept as the slug suffix
/// (32 hex digits = 128 bits).
const MCP_SLUG_HASH_HEX_LEN: usize = 32;
39
40/// Build the storage slug for an MCP client. The readable portion is a
41/// sanitized, length-capped rendering of the raw client; a short hash of the
42/// RAW (un-sanitized) client is appended so that distinct clients that sanitize
43/// to the same readable string (e.g. `a/b`, `a:b`, `a b`, casing variants, or
44/// non-ASCII that collapses to `unknown`) still get distinct directories. The
45/// hash is over the raw bytes, so it is collision-resistant where the readable
46/// slug is not.
47fn sanitize_client(client: &str) -> String {
48    let lower = client.to_ascii_lowercase();
49    let mut out = String::with_capacity(lower.len());
50    let mut last_was_dash = false;
51    for ch in lower.chars() {
52        let keep = ch.is_ascii_alphanumeric() || matches!(ch, '.' | '_' | '-');
53        if keep {
54            out.push(ch);
55            last_was_dash = false;
56        } else if !last_was_dash {
57            out.push('-');
58            last_was_dash = true;
59        }
60    }
61    let trimmed = out.trim_matches(|c| c == '-' || c == '.');
62    let mut readable = if trimmed.is_empty() {
63        "unknown".to_string()
64    } else {
65        trimmed.to_string()
66    };
67    if readable.len() > MCP_SLUG_READABLE_MAX {
68        readable.truncate(MCP_SLUG_READABLE_MAX);
69        // Truncation can leave a trailing separator; trim it for tidiness.
70        readable = readable.trim_end_matches(['-', '.']).to_string();
71        if readable.is_empty() {
72            readable = "unknown".to_string();
73        }
74    }
75
76    // A 128-bit hash suffix prevents hostile same-readable slugs from sharing
77    // storage while keeping directory names short enough for common filesystems.
78    let hash = blake3::hash(client.as_bytes()).to_hex();
79    format!("{readable}--{}", &hash.as_str()[..MCP_SLUG_HASH_HEX_LEN])
80}
81
82impl Serialize for Harness {
83    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
84    where
85        S: Serializer,
86    {
87        serializer.serialize_str(&self.wire_label())
88    }
89}
90
91impl<'de> Deserialize<'de> for Harness {
92    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
93    where
94        D: Deserializer<'de>,
95    {
96        struct HarnessVisitor;
97
98        impl<'de> Visitor<'de> for HarnessVisitor {
99            type Value = Harness;
100
101            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
102                formatter
103                    .write_str("a harness string: 'opencode', 'pi', 'runner', or 'mcp:<client>'")
104            }
105
106            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
107            where
108                E: de::Error,
109            {
110                Harness::from_str(value).map_err(E::custom)
111            }
112        }
113
114        deserializer.deserialize_str(HarnessVisitor)
115    }
116}
117
118impl fmt::Display for Harness {
119    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
120        f.write_str(&self.wire_label())
121    }
122}
123
124impl std::str::FromStr for Harness {
125    type Err = String;
126
127    fn from_str(value: &str) -> Result<Self, Self::Err> {
128        match value {
129            "opencode" => Ok(Harness::Opencode),
130            "pi" => Ok(Harness::Pi),
131            "runner" => Ok(Harness::Runner),
132            other if other.starts_with("mcp:") => {
133                let client = &other[4..];
134                if client.is_empty() {
135                    Err(
136                        "unsupported harness 'mcp:'; mcp client name must be non-empty".to_string(),
137                    )
138                } else {
139                    Ok(Harness::Mcp {
140                        client: client.to_string(),
141                    })
142                }
143            }
144            other => Err(format!(
145                "unsupported harness '{other}'; expected 'opencode', 'pi', 'runner', or 'mcp:<client>'"
146            )),
147        }
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::{sanitize_client, Harness, MCP_SLUG_HASH_HEX_LEN, MCP_SLUG_READABLE_MAX};
    use std::collections::HashSet;

    /// Builds an MCP harness for the given raw client name.
    fn mcp(client: &str) -> Harness {
        Harness::Mcp {
            client: client.to_string(),
        }
    }

    /// Storage segment of the MCP harness for the given raw client name.
    fn segment(client: &str) -> String {
        mcp(client).storage_segment()
    }

    /// JSON encoding of a harness; panics if serialization fails.
    fn to_json(harness: &Harness) -> String {
        serde_json::to_string(harness).unwrap()
    }

    /// Attempts to decode a harness from a JSON document.
    fn from_json(json: &str) -> Result<Harness, serde_json::Error> {
        serde_json::from_str(json)
    }

    #[test]
    fn harness_enum_serde_roundtrip() {
        assert_eq!(to_json(&Harness::Opencode), "\"opencode\"");
        assert_eq!(to_json(&Harness::Pi), "\"pi\"");

        assert_eq!(from_json("\"opencode\"").unwrap(), Harness::Opencode);
        assert_eq!(from_json("\"pi\"").unwrap(), Harness::Pi);
        assert!(from_json("\"claude_code\"").is_err());
    }

    #[test]
    fn opencode_pi_storage_segment_unchanged() {
        assert_eq!(Harness::Opencode.storage_segment(), "opencode");
        assert_eq!(Harness::Pi.storage_segment(), "pi");
    }

    #[test]
    fn runner_round_trips() {
        assert_eq!("runner".parse::<Harness>().unwrap(), Harness::Runner);
        assert_eq!(Harness::Runner.storage_segment(), "runner");
        assert_eq!(to_json(&Harness::Runner), "\"runner\"");
        assert_eq!(from_json("\"runner\"").unwrap(), Harness::Runner);
    }

    #[test]
    fn mcp_round_trips() {
        let h = mcp("claude-code");
        assert_eq!(to_json(&h), "\"mcp:claude-code\"");
        assert_eq!(from_json("\"mcp:claude-code\"").unwrap(), h);
        assert_eq!("mcp:cursor".parse::<Harness>().unwrap(), mcp("cursor"));
        assert!("mcp:".parse::<Harness>().is_err());
    }

    #[test]
    fn storage_segment_hostile_clients_are_path_safe() {
        for client in ["../../etc", "a/b", r"a\b", "a:b", "", "Claude.Code"] {
            let seg = segment(client);
            assert!(
                !seg.is_empty(),
                "segment must be non-empty for client {client:?}"
            );
            assert!(
                !seg.chars().any(|c| matches!(c, '/' | '\\' | ':')),
                "segment {seg:?} must not contain path separators for client {client:?}"
            );
            assert!(
                !seg.contains(".."),
                "segment {seg:?} must not contain '..' for client {client:?}"
            );
            assert!(
                seg.starts_with("mcp--"),
                "segment {seg:?} must use mcp-- prefix"
            );
        }

        // The readable portion survives (lowercased) and a hash suffix follows.
        let claude = segment("Claude.Code");
        assert!(
            claude.starts_with("mcp--claude.code--"),
            "expected readable slug with hash suffix, got {claude:?}"
        );

        // An empty client renders as "unknown" plus the hash of empty bytes.
        let empty = sanitize_client("");
        assert!(
            empty.starts_with("unknown--"),
            "empty client must render unknown-- plus hash, got {empty:?}"
        );
    }

    #[test]
    fn storage_segment_disambiguates_clients_that_sanitize_to_same_slug() {
        // All four names collapse to the readable slug "a-b" yet are DISTINCT
        // clients; the raw-bytes hash suffix must keep their storage segments
        // apart so two different MCP clients never share state.
        let variants = ["a/b", "a:b", "a b", "A-B"].map(segment);
        for s in &variants {
            assert!(
                s.starts_with("mcp--a-b--"),
                "expected shared readable slug a-b, got {s:?}"
            );
            let (_readable, suffix) = s.rsplit_once("--").expect("hash suffix");
            assert_eq!(
                suffix.len(),
                MCP_SLUG_HASH_HEX_LEN,
                "hash suffix must carry 128 bits of disambiguation: {s:?}"
            );
            assert!(
                suffix.chars().all(|ch| ch.is_ascii_hexdigit()),
                "hash suffix must be hex: {s:?}"
            );
        }
        let unique = variants.iter().collect::<HashSet<_>>();
        assert_eq!(
            unique.len(),
            variants.len(),
            "distinct clients must get distinct storage segments: {variants:?}"
        );

        // The same raw client always maps to the same segment.
        assert_eq!(segment("cursor"), segment("cursor"));

        // A very long client: the readable part is capped, so the whole
        // segment stays within prefix + readable cap + separator + hash.
        let long = segment(&"x".repeat(500));
        let bound = "mcp--".len() + MCP_SLUG_READABLE_MAX + "--".len() + MCP_SLUG_HASH_HEX_LEN;
        assert!(
            long.len() <= bound,
            "long client segment must be length-bounded, got len {}",
            long.len()
        );
    }
}