Skip to main content

codec_rs/
safety_policy.rs

1// SPDX-License-Identifier: MIT
2//! Safety-policy descriptor loading, validation, and discovery.
3//!
4//! Rust twin of `@codecai/web`'s `safety_policy.ts` (slice 1) and
5//! `codecai.safety_policy` (slice 11). Same shapes, same errors, same
6//! canonical JSON form for hashing — a descriptor that hashes to
7//! `sha256:abc…` in any client hashes to the identical digest here.
8//!
9//! Used by clients that received `safety_policy_id` + `safety_policy_hash`
10//! in `READY` and want to fetch and surface what the server is
11//! enforcing. The descriptor is the *sanitized*, publishable shape —
12//! categories, actions, classifier family, summary stats — never the
13//! operator's internal banned token IDs / classifier thresholds /
14//! regex patterns.
15//!
16//! Discovery follows the existing tokenizer-map convention:
17//!
18//!   - `<origin>/.well-known/codec/policies/<id>.json`         (mutable)
19//!   - `<origin>/.well-known/codec/policies/sha256/<hex>.json` (immutable)
20//!
21//! A client that received a hash in `READY` SHOULD prefer the
22//! content-addressed sibling — it's provably immutable and skips the
23//! mutable indirection.
24
25use serde::{Deserialize, Serialize};
26use sha2::{Digest, Sha256};
27
28pub const POLICY_WELL_KNOWN_BASE: &str = "/.well-known/codec/policies";
29
30// ── Types ────────────────────────────────────────────────────────────────────
31
32#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
33#[serde(rename_all = "snake_case")]
34pub enum CategoryAction {
35    Stop,
36    Redact,
37    Regenerate,
38    Flag,
39}
40
41#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
42#[serde(rename_all = "snake_case")]
43pub enum ClassifierHost {
44    Server,
45    Client,
46    Both,
47}
48
49#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
50#[serde(rename_all = "snake_case")]
51pub enum EngineFeature {
52    LogitsProcessor,
53    HiddenStates,
54    SamplingChain,
55}
56
57#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
58pub struct Category {
59    pub name: String,
60    pub action: CategoryAction,
61    #[serde(skip_serializing_if = "Option::is_none", default)]
62    pub description: Option<String>,
63}
64
65#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
66pub struct ClassifierBlock {
67    pub family: String,
68    #[serde(skip_serializing_if = "Option::is_none", default)]
69    pub host: Option<ClassifierHost>,
70    #[serde(skip_serializing_if = "Option::is_none", default)]
71    pub requires_engine_features: Option<Vec<EngineFeature>>,
72}
73
74#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
75pub struct RulesSummary {
76    #[serde(skip_serializing_if = "Option::is_none", default)]
77    pub banned_token_id_count: Option<u64>,
78    #[serde(skip_serializing_if = "Option::is_none", default)]
79    pub regex_pattern_count: Option<u64>,
80    #[serde(skip_serializing_if = "Option::is_none", default)]
81    pub grammar_constraint_count: Option<u64>,
82    #[serde(skip_serializing_if = "Option::is_none", default)]
83    pub multi_token_pattern_count: Option<u64>,
84}
85
86#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
87pub struct ClientHooksBlock {
88    #[serde(skip_serializing_if = "Option::is_none", default)]
89    pub prefilter_categories: Option<Vec<String>>,
90    #[serde(skip_serializing_if = "Option::is_none", default)]
91    pub client_classifier_family: Option<String>,
92}
93
94#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
95pub struct PublisherBlock {
96    #[serde(skip_serializing_if = "Option::is_none", default)]
97    pub name: Option<String>,
98    #[serde(skip_serializing_if = "Option::is_none", default)]
99    pub url: Option<String>,
100    #[serde(skip_serializing_if = "Option::is_none", default)]
101    pub contact: Option<String>,
102}
103
104/// The sanitized, publishable safety-policy descriptor.
105///
106/// Matches `spec/safety-policy.schema.json` v1. Field order mirrors the
107/// TS / Python clients so canonical JSON output is byte-identical.
108#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
109pub struct SafetyPolicyDescriptor {
110    pub id: String,
111    pub version: String,
112    pub tokenizers: Vec<String>,
113    pub categories: Vec<Category>,
114    #[serde(skip_serializing_if = "Option::is_none", default)]
115    pub category_registry: Option<String>,
116    pub classifier: ClassifierBlock,
117    #[serde(skip_serializing_if = "Option::is_none", default)]
118    pub rules_summary: Option<RulesSummary>,
119    #[serde(skip_serializing_if = "Option::is_none", default)]
120    pub client_hooks: Option<ClientHooksBlock>,
121    #[serde(skip_serializing_if = "Option::is_none", default)]
122    pub published_at: Option<String>,
123    #[serde(skip_serializing_if = "Option::is_none", default)]
124    pub publisher: Option<PublisherBlock>,
125}
126
127#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
128pub struct SafetyPolicyPointer {
129    pub id: String,
130    pub url: String,
131    pub hash: String,
132    #[serde(skip_serializing_if = "Option::is_none", default)]
133    pub published_at: Option<String>,
134}
135
136// ── Errors ──────────────────────────────────────────────────────────────────
137
138#[derive(Debug, thiserror::Error)]
139pub enum SafetyPolicyError {
140    #[error("SafetyPolicyDescriptor validation failed: {0}")]
141    Validation(String),
142
143    #[error("SafetyPolicyDescriptor parse failed: {0}")]
144    Parse(#[from] serde_json::Error),
145
146    #[error("SafetyPolicyDescriptor hash mismatch.\n  expected: {expected}\n  actual:   {actual}")]
147    HashMismatch { expected: String, actual: String },
148
149    #[error("Invalid policy id {id:?}: {reason}")]
150    InvalidId { id: String, reason: &'static str },
151
152    #[error("Invalid policy hash hex: must be 64-char lowercase hex (got {got:?})")]
153    InvalidHashHex { got: String },
154
155    #[error("Pointer id {got:?} does not match requested id {expected:?}")]
156    PointerIdMismatch { got: String, expected: String },
157
158    #[error("Pointer url must be http(s): got {got:?}")]
159    PointerBadUrl { got: String },
160
161    #[error("Pointer hash must be sha256:<64 hex chars>: got {got:?}")]
162    PointerBadHash { got: String },
163
164    #[error("Inline descriptor id {got:?} does not match requested id {expected:?}")]
165    InlineIdMismatch { got: String, expected: String },
166
167    #[cfg(feature = "http")]
168    #[error("No safety-policy document at {url} (HTTP {status})")]
169    NotFound { url: String, status: u16 },
170
171    #[cfg(feature = "http")]
172    #[error("http error: {0}")]
173    Http(#[from] reqwest::Error),
174}
175
176// ── Validation ──────────────────────────────────────────────────────────────
177//
178// Hand-written shape check matching the TS / Python validators. Run
179// against the parsed serde_json::Value before attempting to deserialize
180// so we get clean error messages rather than serde's terser ones.
181
/// Textual form of the category-name charset rule.
///
/// This string is only interpolated into validation error messages so
/// they quote the same pattern the TS / Python clients reference; the
/// check itself is implemented by hand in `category_name_ok`, not by a
/// regex engine.
const CATEGORY_NAME_RE: &str = "^[a-z0-9_-]+$";
186
/// Returns `true` when `s` is non-empty and made up only of ASCII
/// lowercase letters, ASCII digits, `_` and `-`.
fn category_name_ok(s: &str) -> bool {
    if s.is_empty() {
        return false;
    }
    // Any non-ASCII character contributes bytes >= 0x80, none of which
    // match, so a byte scan is equivalent to a per-char scan here.
    s.bytes()
        .all(|b| matches!(b, b'a'..=b'z' | b'0'..=b'9' | b'_' | b'-'))
}
190
/// Returns `true` when `s` is usable as a policy id inside a URL path.
///
/// Requirements: non-empty; only ASCII lowercase letters, ASCII digits,
/// `_`, `-`, `.` and `/`; no `..` anywhere; no leading or trailing `/`.
fn id_ok(s: &str) -> bool {
    let charset_ok = s
        .bytes()
        .all(|b| matches!(b, b'a'..=b'z' | b'0'..=b'9' | b'_' | b'-' | b'.' | b'/'));
    let traversal_free = !s.contains("..");
    let slash_bounded = s.starts_with('/') || s.ends_with('/');

    !s.is_empty() && charset_ok && traversal_free && !slash_bounded
}
198
/// Returns `true` when `s` is exactly 64 characters drawn from `0-9`
/// and lowercase `a-f`.
fn hex64_lower_ok(s: &str) -> bool {
    // The length test is in bytes; the charset test then guarantees every
    // byte is a single ASCII character.
    s.len() == 64 && s.bytes().all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f'))
}
202
203/// Validate a parsed JSON value as a SafetyPolicyDescriptor candidate.
204///
205/// Use before [`SafetyPolicyDescriptor::from_json`] when you want
206/// human-readable error messages; otherwise serde's deserializer will
207/// surface its own (terser) errors automatically.
208pub fn validate_safety_policy(value: &serde_json::Value) -> Result<(), SafetyPolicyError> {
209    let v = value
210        .as_object()
211        .ok_or_else(|| SafetyPolicyError::Validation("not an object".into()))?;
212
213    let id = v.get("id").and_then(|x| x.as_str()).filter(|s| !s.is_empty())
214        .ok_or_else(|| SafetyPolicyError::Validation("id must be a non-empty string".into()))?;
215    let _ = id;
216
217    v.get("version").and_then(|x| x.as_str())
218        .ok_or_else(|| SafetyPolicyError::Validation("version must be a string".into()))?;
219
220    let tokenizers = v.get("tokenizers").and_then(|x| x.as_array())
221        .filter(|a| !a.is_empty())
222        .ok_or_else(|| SafetyPolicyError::Validation(
223            "tokenizers must be a non-empty array of tokenizer ids".into(),
224        ))?;
225    for t in tokenizers {
226        if !t.is_string() {
227            return Err(SafetyPolicyError::Validation(
228                "tokenizers entries must be strings".into(),
229            ));
230        }
231    }
232
233    let categories = v.get("categories").and_then(|x| x.as_array())
234        .filter(|a| !a.is_empty())
235        .ok_or_else(|| SafetyPolicyError::Validation(
236            "categories must be a non-empty array".into(),
237        ))?;
238    for c in categories {
239        let cat = c.as_object().ok_or_else(|| {
240            SafetyPolicyError::Validation("category entry must be an object".into())
241        })?;
242        let name = cat.get("name").and_then(|x| x.as_str()).ok_or_else(|| {
243            SafetyPolicyError::Validation("category.name must be a string".into())
244        })?;
245        if !category_name_ok(name) {
246            return Err(SafetyPolicyError::Validation(format!(
247                "category.name must match {CATEGORY_NAME_RE} (got {name:?})"
248            )));
249        }
250        let action = cat.get("action").and_then(|x| x.as_str()).ok_or_else(|| {
251            SafetyPolicyError::Validation(format!(
252                "category.action for {name:?} must be one of stop|redact|regenerate|flag"
253            ))
254        })?;
255        if !matches!(action, "stop" | "redact" | "regenerate" | "flag") {
256            return Err(SafetyPolicyError::Validation(format!(
257                "category.action for {name:?} must be one of stop|redact|regenerate|flag"
258            )));
259        }
260        if let Some(desc) = cat.get("description") {
261            if !desc.is_string() && !desc.is_null() {
262                return Err(SafetyPolicyError::Validation(format!(
263                    "category.description for {name:?} must be a string when present"
264                )));
265            }
266        }
267    }
268
269    let cls = v.get("classifier").and_then(|x| x.as_object()).ok_or_else(|| {
270        SafetyPolicyError::Validation("classifier must be an object".into())
271    })?;
272    let family = cls.get("family").and_then(|x| x.as_str()).filter(|s| !s.is_empty())
273        .ok_or_else(|| SafetyPolicyError::Validation(
274            "classifier.family must be a non-empty string".into(),
275        ))?;
276    let _ = family;
277    if let Some(host) = cls.get("host") {
278        if !host.is_null() {
279            let h = host.as_str().ok_or_else(|| {
280                SafetyPolicyError::Validation(format!(
281                    "classifier.host must be one of server|client|both (got {host})"
282                ))
283            })?;
284            if !matches!(h, "server" | "client" | "both") {
285                return Err(SafetyPolicyError::Validation(format!(
286                    "classifier.host must be one of server|client|both (got {h:?})"
287                )));
288            }
289        }
290    }
291    if let Some(feats) = cls.get("requires_engine_features") {
292        if !feats.is_null() {
293            let arr = feats.as_array().ok_or_else(|| {
294                SafetyPolicyError::Validation(
295                    "classifier.requires_engine_features must be an array".into(),
296                )
297            })?;
298            for f in arr {
299                let s = f.as_str().ok_or_else(|| {
300                    SafetyPolicyError::Validation(
301                        "classifier.requires_engine_features entry must be a string".into(),
302                    )
303                })?;
304                if !matches!(s, "logits_processor" | "hidden_states" | "sampling_chain") {
305                    return Err(SafetyPolicyError::Validation(format!(
306                        "classifier.requires_engine_features entry must be one of \
307                         logits_processor|hidden_states|sampling_chain (got {s:?})"
308                    )));
309                }
310            }
311        }
312    }
313
314    if let Some(rs) = v.get("rules_summary") {
315        if !rs.is_null() {
316            let m = rs.as_object().ok_or_else(|| {
317                SafetyPolicyError::Validation("rules_summary must be an object when present".into())
318            })?;
319            for k in [
320                "banned_token_id_count",
321                "regex_pattern_count",
322                "grammar_constraint_count",
323                "multi_token_pattern_count",
324            ] {
325                if let Some(val) = m.get(k) {
326                    if !val.is_null() && !val.as_u64().is_some() {
327                        return Err(SafetyPolicyError::Validation(format!(
328                            "rules_summary.{k} must be a non-negative integer when present"
329                        )));
330                    }
331                }
332            }
333        }
334    }
335
336    Ok(())
337}
338
339// ── Builder ─────────────────────────────────────────────────────────────────
340
341impl SafetyPolicyDescriptor {
342    /// Parse + validate a JSON byte slice into a SafetyPolicyDescriptor.
343    ///
344    /// Validates first via [`validate_safety_policy`] for clean error
345    /// messages, then deserializes via serde. Rejects shapes the
346    /// validator accepts but serde would silently coerce (e.g. extra
347    /// fields are ignored on the rust side which is the desired
348    /// forward-compat behavior).
349    pub fn from_json(bytes: &[u8]) -> Result<Self, SafetyPolicyError> {
350        let parsed: serde_json::Value = serde_json::from_slice(bytes)?;
351        validate_safety_policy(&parsed)?;
352        let descriptor: SafetyPolicyDescriptor = serde_json::from_value(parsed)?;
353        Ok(descriptor)
354    }
355
356    /// Canonical JSON serialization for hashing + well-known publish.
357    ///
358    /// Matches the TS / Python / supervisor format: 2-space indent +
359    /// trailing newline. Fields with `None` values are omitted (via
360    /// serde's `skip_serializing_if`) so the canonical bytes match
361    /// across stacks.
362    pub fn canonical_bytes(&self) -> Result<Vec<u8>, SafetyPolicyError> {
363        let mut buf = Vec::new();
364        let formatter = serde_json::ser::PrettyFormatter::with_indent(b"  ");
365        let mut ser = serde_json::Serializer::with_formatter(&mut buf, formatter);
366        self.serialize(&mut ser)?;
367        buf.push(b'\n');
368        Ok(buf)
369    }
370
371    /// Canonical sha256 hash of a descriptor.
372    ///
373    /// Returns `sha256:<64 hex chars>` matching what
374    /// `codecai-maps policies hash` emits and what servers should
375    /// publish in `READY.safety_policy_hash`.
376    pub fn hash(&self) -> Result<String, SafetyPolicyError> {
377        let bytes = self.canonical_bytes()?;
378        let mut h = Sha256::new();
379        h.update(&bytes);
380        Ok(format!("sha256:{:x}", h.finalize()))
381    }
382}
383
384fn parse_hash(hash: &str) -> Result<String, SafetyPolicyError> {
385    if let Some(rest) = hash.strip_prefix("sha256:") {
386        let lower = rest.to_ascii_lowercase();
387        if !hex64_lower_ok(&lower) {
388            return Err(SafetyPolicyError::InvalidHashHex { got: hash.to_string() });
389        }
390        Ok(lower)
391    } else {
392        let lower = hash.to_ascii_lowercase();
393        if !hex64_lower_ok(&lower) {
394            return Err(SafetyPolicyError::InvalidHashHex { got: hash.to_string() });
395        }
396        Ok(lower)
397    }
398}
399
400// ── URL builders ────────────────────────────────────────────────────────────
401
/// Drop at most one trailing `/` from `s`; anything else is returned as-is.
fn strip_trailing_slash(s: &str) -> &str {
    match s.strip_suffix('/') {
        Some(trimmed) => trimmed,
        None => s,
    }
}
405
406/// Per-policy URL by mutable id (e.g. `acme/strict-v3`).
407pub fn well_known_policy_url(origin: &str, policy_id: &str) -> Result<String, SafetyPolicyError> {
408    if !id_ok(policy_id) {
409        return Err(SafetyPolicyError::InvalidId {
410            id: policy_id.to_string(),
411            reason: "must match [a-z0-9._/-]+ and contain no traversal",
412        });
413    }
414    Ok(format!(
415        "{}{POLICY_WELL_KNOWN_BASE}/{}.json",
416        strip_trailing_slash(origin),
417        policy_id,
418    ))
419}
420
421/// Content-addressed URL by sha256 hex (no `sha256:` prefix).
422pub fn well_known_policy_hash_url(origin: &str, hash_hex: &str) -> Result<String, SafetyPolicyError> {
423    let lower = hash_hex.to_ascii_lowercase();
424    if !hex64_lower_ok(&lower) {
425        return Err(SafetyPolicyError::InvalidHashHex { got: hash_hex.to_string() });
426    }
427    Ok(format!(
428        "{}{POLICY_WELL_KNOWN_BASE}/sha256/{}.json",
429        strip_trailing_slash(origin),
430        lower,
431    ))
432}
433
434// ── Pointer detection ───────────────────────────────────────────────────────
435
436fn is_pointer_shape(value: &serde_json::Value) -> bool {
437    let Some(obj) = value.as_object() else { return false; };
438    obj.get("id").is_some_and(|v| v.is_string())
439        && obj.get("url").is_some_and(|v| v.is_string())
440        && obj.get("hash").is_some_and(|v| v.is_string())
441        // Inline descriptors always carry `categories`; pointers never do.
442        && !obj.contains_key("categories")
443}
444
445fn validate_pointer(
446    value: &serde_json::Value,
447    expected_id: &str,
448) -> Result<SafetyPolicyPointer, SafetyPolicyError> {
449    let pointer: SafetyPolicyPointer = serde_json::from_value(value.clone())?;
450    if pointer.id != expected_id {
451        return Err(SafetyPolicyError::PointerIdMismatch {
452            got: pointer.id,
453            expected: expected_id.to_string(),
454        });
455    }
456    if !(pointer.url.starts_with("https://") || pointer.url.starts_with("http://")) {
457        return Err(SafetyPolicyError::PointerBadUrl { got: pointer.url });
458    }
459    if !pointer.hash.starts_with("sha256:") || !hex64_lower_ok(&pointer.hash[7..].to_ascii_lowercase()) {
460        return Err(SafetyPolicyError::PointerBadHash { got: pointer.hash });
461    }
462    Ok(pointer)
463}
464
465// ── HTTP loader + discovery (gated on `http` feature) ───────────────────────
466
#[cfg(feature = "http")]
mod http_impl {
    use super::*;

    /// Build the `reqwest` client used by the loaders: fixed user agent,
    /// gzip and brotli response decoding enabled.
    fn build_async_client() -> Result<reqwest::Client, reqwest::Error> {
        reqwest::Client::builder()
            .user_agent("codec-rs/0.1")
            .gzip(true)
            .brotli(true)
            .build()
    }

    /// Check that `bytes` hash (SHA-256, lowercase hex) to `expected_hex`.
    ///
    /// `expected_hex` must already be normalized bare hex, as returned by
    /// `parse_hash`.
    fn verify_hash(bytes: &[u8], expected_hex: String) -> Result<(), SafetyPolicyError> {
        let actual = format!("{:x}", Sha256::digest(bytes));
        if actual == expected_hex {
            Ok(())
        } else {
            Err(SafetyPolicyError::HashMismatch {
                expected: expected_hex,
                actual,
            })
        }
    }

    /// GET `url`, mapping HTTP 404 to `NotFound` and any other error
    /// status or transport failure to `Http`.
    async fn get_existing(
        client: &reqwest::Client,
        url: String,
    ) -> Result<reqwest::Response, SafetyPolicyError> {
        let resp = client.get(&url).send().await?;
        if resp.status() == reqwest::StatusCode::NOT_FOUND {
            return Err(SafetyPolicyError::NotFound {
                url,
                status: resp.status().as_u16(),
            });
        }
        Ok(resp.error_for_status()?)
    }

    /// Fetch `url` with an existing client, verify the optional hash, and
    /// parse the body as a descriptor.
    async fn load_with_client(
        client: &reqwest::Client,
        url: &str,
        hash: Option<&str>,
    ) -> Result<SafetyPolicyDescriptor, SafetyPolicyError> {
        // Normalize the expected hash up front so a malformed value is
        // rejected without spending a network round trip.
        let expected_hex = hash.map(parse_hash).transpose()?;

        let resp = client.get(url).send().await?.error_for_status()?;
        let bytes = resp.bytes().await?;

        if let Some(expected_hex) = expected_hex {
            verify_hash(&bytes, expected_hex)?;
        }

        SafetyPolicyDescriptor::from_json(&bytes)
    }

    /// Turn a fetched well-known document into a descriptor.
    ///
    /// A pointer-shaped document is validated against `id` and followed;
    /// anything else is treated as an inline descriptor whose `id` must
    /// equal the requested one.
    async fn resolve_document(
        client: &reqwest::Client,
        bytes: &[u8],
        id: &str,
    ) -> Result<SafetyPolicyDescriptor, SafetyPolicyError> {
        let parsed: serde_json::Value = serde_json::from_slice(bytes)?;

        if is_pointer_shape(&parsed) {
            let pointer = validate_pointer(&parsed, id)?;
            // NOTE(review): the descriptor behind the pointer is hash-verified
            // but its own `id` is not compared with the requested id —
            // confirm this matches the TS / Python clients.
            return load_with_client(client, &pointer.url, Some(&pointer.hash)).await;
        }

        // Same steps as `SafetyPolicyDescriptor::from_json`, reusing the
        // value parsed above instead of parsing the bytes a second time.
        validate_safety_policy(&parsed)?;
        let descriptor: SafetyPolicyDescriptor = serde_json::from_value(parsed)?;
        if descriptor.id != id {
            return Err(SafetyPolicyError::InlineIdMismatch {
                got: descriptor.id,
                expected: id.to_string(),
            });
        }
        Ok(descriptor)
    }

    /// Fetch and verify a safety-policy descriptor from `url`.
    ///
    /// If `hash` is provided (with or without a `sha256:` prefix), the
    /// fetched bytes MUST hash to it. Nothing is cached: every call
    /// performs a fresh HTTP request.
    ///
    /// # Errors
    ///
    /// * [`SafetyPolicyError::InvalidHashHex`] — `hash` is malformed
    ///   (reported before any request is sent).
    /// * [`SafetyPolicyError::Http`] — client construction, transport
    ///   failure, or a non-success status.
    /// * [`SafetyPolicyError::HashMismatch`] — the body does not hash to
    ///   `hash`.
    /// * [`SafetyPolicyError::Parse`] / [`SafetyPolicyError::Validation`]
    ///   — the body is not a valid descriptor.
    pub async fn load_safety_policy(
        url: &str,
        hash: Option<&str>,
    ) -> Result<SafetyPolicyDescriptor, SafetyPolicyError> {
        // Reject a malformed hash before building a client or touching
        // the network; `load_with_client` re-normalizes the same string.
        if let Some(expected) = hash {
            parse_hash(expected)?;
        }
        let client = build_async_client()?;
        load_with_client(&client, url, hash).await
    }

    /// Resolve a safety-policy descriptor via `.well-known/codec/policies/`.
    ///
    /// If `hash` is provided, fetches the immutable content-addressed
    /// sibling at `<origin>/.well-known/codec/policies/sha256/<hex>.json`
    /// and verifies the bytes match. Otherwise fetches the mutable
    /// per-id document. In both cases a pointer-shaped document is
    /// followed, and an inline descriptor must carry the requested `id`.
    ///
    /// # Errors
    ///
    /// * [`SafetyPolicyError::InvalidHashHex`] / [`SafetyPolicyError::InvalidId`]
    ///   — `hash` or `id` cannot be turned into a well-known URL.
    /// * [`SafetyPolicyError::NotFound`] — the well-known URL returned 404.
    /// * [`SafetyPolicyError::Http`] — any other HTTP or transport failure.
    /// * [`SafetyPolicyError::HashMismatch`] — fetched bytes do not match
    ///   the expected digest.
    /// * Pointer / inline id errors and parse / validation errors from
    ///   processing the fetched document.
    pub async fn discover_safety_policy(
        origin: &str,
        id: &str,
        hash: Option<&str>,
    ) -> Result<SafetyPolicyDescriptor, SafetyPolicyError> {
        let client = build_async_client()?;

        let bytes = if let Some(h) = hash {
            let hash_hex = parse_hash(h)?;
            let url = well_known_policy_hash_url(origin, &hash_hex)?;
            let body = get_existing(&client, url).await?.bytes().await?;
            verify_hash(&body, hash_hex)?;
            body
        } else {
            let url = well_known_policy_url(origin, id)?;
            get_existing(&client, url).await?.bytes().await?
        };

        resolve_document(&client, &bytes, id).await
    }
}
578
579#[cfg(feature = "http")]
580pub use http_impl::{discover_safety_policy, load_safety_policy};
581
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{json, Value};

    const ORIGIN: &str = "https://acme.example";

    /// Baseline fixture: a valid descriptor document with several
    /// optional sections populated.
    fn valid_json() -> Value {
        json!({
            "id": "acme/strict-v3",
            "version": "1",
            "tokenizers": ["meta-llama/llama-3"],
            "categories": [
                {"name": "secrets", "action": "stop"},
                {"name": "pii", "action": "redact", "description": "Email and phone."},
            ],
            "classifier": {
                "family": "llama-guard-3-1b",
                "host": "server",
                "requires_engine_features": ["logits_processor", "sampling_chain"],
            },
            "rules_summary": {
                "banned_token_id_count": 4128,
                "regex_pattern_count": 47,
            },
            "client_hooks": {
                "prefilter_categories": ["secrets", "pii"],
                "client_classifier_family": "prompt-guard-86m",
            },
            "published_at": "2026-05-09T00:00:00Z",
        })
    }

    /// Baseline fixture with one top-level member replaced.
    fn with_field(key: &str, replacement: Value) -> Value {
        let mut doc = valid_json();
        doc[key] = replacement;
        doc
    }

    /// Run a JSON document through `from_json`, panicking on failure.
    fn descriptor_from(doc: &Value) -> SafetyPolicyDescriptor {
        let raw = serde_json::to_vec(doc).unwrap();
        SafetyPolicyDescriptor::from_json(&raw).unwrap()
    }

    fn valid_descriptor() -> SafetyPolicyDescriptor {
        descriptor_from(&valid_json())
    }

    // ── Validation ─────────────────────────────────────────────────────────

    #[test]
    fn validate_accepts_minimal_valid_descriptor() {
        assert!(validate_safety_policy(&valid_json()).is_ok());
    }

    #[test]
    fn validate_rejects_missing_required_fields() {
        let rejected = [
            json!({}),
            with_field("id", Value::String(String::new())),
            with_field("tokenizers", json!([])),
            with_field("categories", json!([])),
        ];
        for doc in &rejected {
            assert!(validate_safety_policy(doc).is_err());
        }
    }

    #[test]
    fn validate_rejects_bad_category_name() {
        let doc = with_field("categories", json!([{"name": "BadCaps", "action": "stop"}]));
        assert!(validate_safety_policy(&doc).is_err());
    }

    #[test]
    fn validate_rejects_unknown_action() {
        let doc = with_field(
            "categories",
            json!([{"name": "secrets", "action": "banhammer"}]),
        );
        assert!(validate_safety_policy(&doc).is_err());
    }

    #[test]
    fn validate_rejects_unknown_engine_features() {
        let mut doc = valid_json();
        doc["classifier"]["requires_engine_features"] = json!(["weather_api"]);
        assert!(validate_safety_policy(&doc).is_err());
    }

    // ── Hash determinism ───────────────────────────────────────────────────

    #[test]
    fn hash_is_deterministic_for_identical_input() {
        let descriptor = valid_descriptor();
        let first = descriptor.hash().unwrap();
        let second = descriptor.hash().unwrap();
        assert_eq!(first, second);

        let hex_len = first.strip_prefix("sha256:").map(str::len);
        assert_eq!(hex_len, Some(64));
    }

    #[test]
    fn hash_differs_when_category_action_changes() {
        let baseline = valid_descriptor();

        let mut altered_doc = valid_json();
        altered_doc["categories"][0]["action"] = Value::String("flag".into());
        let altered = descriptor_from(&altered_doc);

        assert_ne!(baseline.hash().unwrap(), altered.hash().unwrap());
    }

    #[test]
    fn canonical_bytes_match_2_space_indent_with_trailing_newline() {
        let raw = valid_descriptor().canonical_bytes().unwrap();
        let text = std::str::from_utf8(&raw).unwrap();

        assert!(text.ends_with('\n'));
        // Pretty-printed output has a newline followed by a two-space indent.
        assert!(text.contains("\n  "));
        // The canonical text is still parseable JSON.
        assert!(serde_json::from_str::<Value>(text).is_ok());
    }

    // ── URL builders ───────────────────────────────────────────────────────

    #[test]
    fn well_known_policy_url_preserves_slashes_and_strips_trailing() {
        let built = well_known_policy_url("https://acme.example/", "acme/strict-v3").unwrap();
        assert_eq!(
            built,
            "https://acme.example/.well-known/codec/policies/acme/strict-v3.json"
        );
    }

    #[test]
    fn well_known_policy_url_rejects_traversal() {
        for bad_id in ["../etc", "/abs", "trailing/"] {
            assert!(well_known_policy_url(ORIGIN, bad_id).is_err());
        }
    }

    #[test]
    fn well_known_policy_url_rejects_bad_charset() {
        assert!(well_known_policy_url(ORIGIN, "Acme/Strict").is_err());
    }

    #[test]
    fn well_known_policy_hash_url_uses_sha256_path() {
        let hex = "a".repeat(64);
        let built = well_known_policy_hash_url(ORIGIN, &hex).unwrap();
        assert_eq!(
            built,
            format!("https://acme.example/.well-known/codec/policies/sha256/{hex}.json")
        );
    }

    #[test]
    fn well_known_policy_hash_url_rejects_malformed_hex() {
        assert!(well_known_policy_hash_url(ORIGIN, "not-hex").is_err());
    }

    // ── Round-trip ─────────────────────────────────────────────────────────

    #[test]
    fn descriptor_round_trip_canonical_bytes_to_json() {
        let original = valid_descriptor();
        let canonical = original.canonical_bytes().unwrap();
        let reparsed = SafetyPolicyDescriptor::from_json(&canonical).unwrap();
        assert_eq!(original, reparsed);
    }

    #[test]
    fn from_json_rejects_bad_descriptor() {
        assert!(SafetyPolicyDescriptor::from_json(b"{}").is_err());
    }
}