Skip to main content

hmd_patch/
lib.rs

1use hmd_core::TomlValueObject;
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use sha2::{Digest, Sha256};
5use toml_edit::DocumentMut;
6
7#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
8#[serde(rename_all = "camelCase")]
9pub struct HmdPatch {
10    pub patch_version: String,
11    pub target_hash: Option<String>,
12    pub operations: Vec<PatchOperation>,
13    #[serde(skip_serializing_if = "Option::is_none")]
14    pub author: Option<String>,
15    #[serde(skip_serializing_if = "Option::is_none")]
16    pub created: Option<String>,
17}
18
19#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
20#[serde(tag = "op", rename_all = "snake_case")]
21pub enum PatchOperation {
22    SetMeta {
23        target: String,
24        field: String,
25        value: Value,
26        #[serde(default, skip_serializing_if = "Option::is_none")]
27        reason: Option<String>,
28        #[serde(default, skip_serializing_if = "Option::is_none")]
29        label: Option<String>,
30        #[serde(default, rename = "createdBy", skip_serializing_if = "Option::is_none")]
31        created_by: Option<String>,
32    },
33    ReplaceBody {
34        target: String,
35        markdown: String,
36        #[serde(default, skip_serializing_if = "Option::is_none")]
37        reason: Option<String>,
38        #[serde(default, skip_serializing_if = "Option::is_none")]
39        label: Option<String>,
40        #[serde(default, rename = "createdBy", skip_serializing_if = "Option::is_none")]
41        created_by: Option<String>,
42    },
43    AppendBlock {
44        target: String,
45        block: AppendBlock,
46        #[serde(default, skip_serializing_if = "Option::is_none")]
47        reason: Option<String>,
48        #[serde(default, skip_serializing_if = "Option::is_none")]
49        label: Option<String>,
50        #[serde(default, rename = "createdBy", skip_serializing_if = "Option::is_none")]
51        created_by: Option<String>,
52    },
53}
54
55#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
56#[serde(rename_all = "camelCase")]
57pub struct AppendBlock {
58    pub block_type: String,
59    #[serde(default)]
60    pub meta: TomlValueObject,
61    #[serde(default)]
62    pub markdown: String,
63}
64
/// Failure raised while validating or applying a patch; carries a human-readable message.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PatchError {
    /// Description of what went wrong, shown verbatim by `Display`.
    pub message: String,
}
69
70impl std::fmt::Display for PatchError {
71    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        formatter.write_str(&self.message)
73    }
74}
75
// `PatchError` wraps no underlying error, so the default `Error` methods are used; the
// user-facing text comes from its `Display` impl.
impl std::error::Error for PatchError {}
77
78pub fn apply_patch(source: &str, patch: &HmdPatch) -> Result<String, PatchError> {
79    if patch.patch_version != "0.1" {
80        return Err(error(format!(
81            "unsupported patch version '{}'",
82            patch.patch_version
83        )));
84    }
85
86    if let Some(target_hash) = &patch.target_hash {
87        let actual = source_hash(source);
88        if target_hash != &actual {
89            return Err(error(format!(
90                "target hash mismatch: expected {target_hash}, got {actual}"
91            )));
92        }
93    }
94
95    validate_operations(source, &patch.operations)?;
96
97    let mut current = source.to_string();
98    for operation in &patch.operations {
99        current = apply_operation(&current, operation)?;
100    }
101    validate_no_duplicate_ids(&current)?;
102    Ok(current)
103}
104
105fn validate_operations(source: &str, operations: &[PatchOperation]) -> Result<(), PatchError> {
106    let document = hmd_parse::parse_document(source);
107
108    for operation in operations {
109        match operation {
110            PatchOperation::SetMeta { target, .. } | PatchOperation::ReplaceBody { target, .. } => {
111                validate_block_target(&document, target)?;
112            }
113            PatchOperation::AppendBlock { target, block, .. } => {
114                if target != "/document" {
115                    validate_block_target(&document, target)?;
116                }
117                validate_appended_block_id(&document, block)?;
118            }
119        }
120    }
121
122    Ok(())
123}
124
125fn validate_appended_block_id(
126    document: &hmd_core::HmdDocument,
127    block: &AppendBlock,
128) -> Result<(), PatchError> {
129    let Some(id) = block.meta.get("id").and_then(Value::as_str) else {
130        return Ok(());
131    };
132
133    if document.references.ids.contains_key(id)
134        || document
135            .references
136            .duplicates
137            .iter()
138            .any(|record| record.id == id)
139    {
140        return Err(error(format!(
141            "append_block would create duplicate id '{id}'"
142        )));
143    }
144
145    Ok(())
146}
147
148fn validate_no_duplicate_ids(source: &str) -> Result<(), PatchError> {
149    let document = hmd_parse::parse_document(source);
150    if let Some(record) = document.references.duplicates.first() {
151        return Err(error(format!(
152            "patch result contains duplicate id '{}'",
153            record.id
154        )));
155    }
156    Ok(())
157}
158
159fn validate_block_target(document: &hmd_core::HmdDocument, target: &str) -> Result<(), PatchError> {
160    let Some(id) = target.strip_prefix("/blocks/") else {
161        return Err(error(format!(
162            "patch target '{target}' must use /blocks/<id>"
163        )));
164    };
165
166    if document
167        .references
168        .duplicates
169        .iter()
170        .any(|record| record.id == id)
171    {
172        return Err(error(format!("patch target '{target}' is ambiguous")));
173    }
174
175    if !document.references.ids.contains_key(id) {
176        return Err(error(format!("patch target '{target}' was not found")));
177    }
178
179    Ok(())
180}
181
182fn apply_operation(source: &str, operation: &PatchOperation) -> Result<String, PatchError> {
183    match operation {
184        PatchOperation::SetMeta {
185            target,
186            field,
187            value,
188            ..
189        } => apply_set_meta(source, target, field, value),
190        PatchOperation::ReplaceBody {
191            target, markdown, ..
192        } => apply_replace_body(source, target, markdown),
193        PatchOperation::AppendBlock { target, block, .. } => {
194            apply_append_block(source, target, block)
195        }
196    }
197}
198
199fn apply_set_meta(
200    source: &str,
201    target: &str,
202    field: &str,
203    value: &Value,
204) -> Result<String, PatchError> {
205    let span = find_target_span(source, target)?;
206    let value_source = toml_literal(value)?;
207    let mut output = String::new();
208
209    if let Some(meta) = &span.meta {
210        let updated = set_meta_field(&source[meta.content_start..meta.content_end], field, value)?;
211        output.push_str(&source[..meta.content_start]);
212        output.push_str(&updated);
213        output.push_str(&source[meta.content_end..]);
214    } else {
215        output.push_str(&source[..span.opener_end]);
216        output.push_str("+++\n");
217        output.push_str(field);
218        output.push_str(" = ");
219        output.push_str(&value_source);
220        output.push('\n');
221        output.push_str("+++\n");
222        output.push_str(&source[span.opener_end..]);
223    }
224
225    Ok(output)
226}
227
228fn apply_replace_body(source: &str, target: &str, markdown: &str) -> Result<String, PatchError> {
229    let span = find_target_span(source, target)?;
230    let mut output = String::new();
231    output.push_str(&source[..span.body_start]);
232    output.push_str(&body_replacement(markdown));
233    output.push_str(&source[span.close_start..]);
234    Ok(output)
235}
236
237fn apply_append_block(
238    source: &str,
239    target: &str,
240    block: &AppendBlock,
241) -> Result<String, PatchError> {
242    let block_source = block_to_source(block);
243    let (insert_at, insertion) = if target == "/document" {
244        let prefix = if source.ends_with('\n') { "\n" } else { "\n\n" };
245        (source.len(), format!("{prefix}{block_source}"))
246    } else {
247        let span = find_target_span(source, target)?;
248        (span.close_start, format!("\n{block_source}\n"))
249    };
250
251    let mut output = String::new();
252    output.push_str(&source[..insert_at]);
253    output.push_str(&insertion);
254    output.push_str(&source[insert_at..]);
255    Ok(output)
256}
257
258fn find_target_span(source: &str, target: &str) -> Result<BlockSpan, PatchError> {
259    let id = target
260        .strip_prefix("/blocks/")
261        .ok_or_else(|| error(format!("patch target '{target}' must use /blocks/<id>")))?;
262    let spans = scan_block_spans(source)?;
263    spans
264        .into_iter()
265        .find(|span| span.id.as_deref() == Some(id))
266        .ok_or_else(|| error(format!("patch target '{target}' was not found")))
267}
268
269fn set_meta_field(source: &str, field: &str, value: &Value) -> Result<String, PatchError> {
270    let value_source = toml_literal(value)?;
271    let replacement = format!("{field} = {value_source}\n");
272    let mut output = String::new();
273    let mut replaced = false;
274
275    for line in split_inclusive_or_once(source) {
276        if is_toml_key_line(line, field) {
277            output.push_str(&replacement);
278            replaced = true;
279        } else {
280            output.push_str(line);
281        }
282    }
283
284    if !replaced {
285        if !output.is_empty() && !output.ends_with('\n') {
286            output.push('\n');
287        }
288        output.push_str(&replacement);
289    }
290
291    output.parse::<DocumentMut>().map_err(|parse_error| {
292        error(format!(
293            "set_meta produced invalid TOML metadata for field '{field}': {parse_error}"
294        ))
295    })?;
296
297    Ok(output)
298}
299
/// Returns true when `line`, after leading spaces/tabs, is exactly `field` followed by
/// optional spaces/tabs and then `=`. Quoted and dotted keys are not recognized.
fn is_toml_key_line(line: &str, field: &str) -> bool {
    line.trim_start_matches([' ', '\t'])
        .strip_prefix(field)
        .is_some_and(|after_key| after_key.trim_start_matches([' ', '\t']).starts_with('='))
}
307
/// Normalizes replacement body markdown.
///
/// Leading and trailing newlines are dropped. Non-empty content is wrapped as
/// `"\n<content>\n\n"`; an empty body collapses to a single `"\n"`.
fn body_replacement(markdown: &str) -> String {
    match markdown.trim_matches('\n') {
        "" => String::from("\n"),
        content => format!("\n{content}\n\n"),
    }
}
316
317fn block_to_source(block: &AppendBlock) -> String {
318    let mut output = String::new();
319    output.push_str(&format!("::: {}\n", block.block_type));
320    output.push_str("+++\n");
321    for (key, value) in &block.meta {
322        output.push_str(key);
323        output.push_str(" = ");
324        output.push_str(&toml_literal(value).expect("JSON value serializes to TOML literal"));
325        output.push('\n');
326    }
327    output.push_str("+++\n");
328    if !block.markdown.trim().is_empty() {
329        output.push('\n');
330        output.push_str(block.markdown.trim_matches('\n'));
331        output.push_str("\n\n");
332    }
333    output.push_str(":::\n");
334    output
335}
336
337fn toml_literal(value: &Value) -> Result<String, PatchError> {
338    match value {
339        Value::Null => Err(error("null is not a valid TOML metadata value")),
340        Value::Bool(value) => Ok(value.to_string()),
341        Value::Number(value) => Ok(value.to_string()),
342        Value::String(value) => serde_json::to_string(value)
343            .map_err(|error| self::error(format!("failed to encode TOML string: {error}"))),
344        Value::Array(values) => {
345            let values = values
346                .iter()
347                .map(toml_literal)
348                .collect::<Result<Vec<_>, _>>()?;
349            Ok(format!("[{}]", values.join(", ")))
350        }
351        Value::Object(values) => {
352            let values = values
353                .iter()
354                .map(|(key, value)| Ok(format!("{key} = {}", toml_literal(value)?)))
355                .collect::<Result<Vec<_>, PatchError>>()?;
356            Ok(format!("{{ {} }}", values.join(", ")))
357        }
358    }
359}
360
361#[derive(Debug, Clone)]
362struct BlockSpan {
363    id: Option<String>,
364    opener_end: usize,
365    body_start: usize,
366    close_start: usize,
367    meta: Option<MetaSpan>,
368}
369
/// Half-open byte range `content_start..content_end` of the text between a block's
/// opening and closing `+++` lines (the delimiters themselves are excluded).
#[derive(Clone, Debug)]
struct MetaSpan {
    content_start: usize,
    content_end: usize,
}
375
376#[derive(Debug, Clone)]
377struct OpenBlock {
378    id: Option<String>,
379    fence_length: usize,
380    opener_end: usize,
381    body_start: usize,
382    meta: Option<MetaSpan>,
383}
384
/// Scans `source` for `:::` fenced blocks and returns the byte spans of every closed block.
///
/// Scanning starts after any document-level `+++` front matter (see `body_start_index`).
/// Open blocks are kept on a stack so that nesting works. A line closes only the innermost
/// open block, and only if it consists solely of at least as many colons as that block's
/// opener. A `+++` line directly after an opener starts the block's metadata, which runs
/// to the next `+++` line anywhere later in the input; its string `id` key (if any)
/// becomes the span's id. Blocks still open at end of input are not reported. Spans are
/// pushed in closing order, so an inner block appears before the block containing it.
///
/// NOTE(review): this is a standalone line scanner, separate from `hmd_parse`. Lines
/// inside code fences are not special-cased here; confirm it agrees with `hmd_parse`
/// on such input.
///
/// # Errors
///
/// Returns an error naming the 1-based line number of a metadata `+++` opener that has
/// no later `+++` line to close it.
fn scan_block_spans(source: &str) -> Result<Vec<BlockSpan>, PatchError> {
    let lines = collect_lines(source);
    let mut spans = Vec::new();
    let mut stack: Vec<OpenBlock> = Vec::new();
    let mut index = body_start_index(&lines);

    while index < lines.len() {
        let line = &lines[index];

        // Closers are checked first, and only against the innermost open block.
        if let Some(open) = stack.last() {
            if is_closer(line.content, open.fence_length) {
                let open = stack.pop().expect("stack was not empty");
                spans.push(BlockSpan {
                    id: open.id,
                    opener_end: open.opener_end,
                    body_start: open.body_start,
                    close_start: line.start,
                    meta: open.meta,
                });
                index += 1;
                continue;
            }
        }

        if let Some(opener) = parse_opener(line.content) {
            // Without metadata, the body starts right after the opener line.
            let mut open = OpenBlock {
                id: None,
                fence_length: opener.fence_length,
                opener_end: line.end,
                body_start: line.end,
                meta: None,
            };
            index += 1;

            // Metadata is recognized only on the line immediately following the opener.
            if lines
                .get(index)
                .is_some_and(|line| is_plus_delimiter(line.content))
            {
                let meta_open = &lines[index];
                let Some(meta_close_index) =
                    lines
                        .iter()
                        .enumerate()
                        .skip(index + 1)
                        .find_map(|(candidate, line)| {
                            is_plus_delimiter(line.content).then_some(candidate)
                        })
                else {
                    return Err(error(format!(
                        "unterminated block metadata at line {}",
                        meta_open.number
                    )));
                };
                // The closing `+++` lies strictly after `index`, so `index + 1` always
                // exists here; the fallback only guards the lookup.
                let content_start = lines
                    .get(index + 1)
                    .map(|line| line.start)
                    .unwrap_or(meta_open.end);
                let content_end = lines[meta_close_index].start;
                let meta_source = &source[content_start..content_end];
                open.id = parse_meta_id(meta_source);
                open.body_start = lines[meta_close_index].end;
                open.meta = Some(MetaSpan {
                    content_start,
                    content_end,
                });
                // Resume after the closing `+++`; metadata lines are never openers/closers.
                index = meta_close_index + 1;
            }

            stack.push(open);
            continue;
        }

        index += 1;
    }

    Ok(spans)
}
462
463fn body_start_index(lines: &[Line<'_>]) -> usize {
464    if !lines
465        .first()
466        .is_some_and(|line| is_plus_delimiter(line.content))
467    {
468        return 0;
469    }
470
471    lines
472        .iter()
473        .enumerate()
474        .skip(1)
475        .find_map(|(index, line)| is_plus_delimiter(line.content).then_some(index + 1))
476        .unwrap_or(lines.len())
477}
478
479fn parse_meta_id(source: &str) -> Option<String> {
480    source.parse::<DocumentMut>().ok().and_then(|document| {
481        document
482            .get("id")
483            .and_then(|value| value.as_str())
484            .map(str::to_string)
485    })
486}
487
/// One source line, with its terminator stripped from `content`.
#[derive(Clone, Copy, Debug)]
struct Line<'a> {
    /// Line text without a trailing `\n` / `\r\n`.
    content: &'a str,
    /// Byte offset of the first character of the line.
    start: usize,
    /// Byte offset just past the line, including its terminator.
    end: usize,
    /// 1-based line number.
    number: usize,
}
495
496fn collect_lines(source: &str) -> Vec<Line<'_>> {
497    let mut lines = Vec::new();
498    let mut start = 0;
499    let mut number = 1;
500
501    for raw in source.split_inclusive('\n') {
502        let end = start + raw.len();
503        lines.push(Line {
504            content: strip_line_ending(raw),
505            start,
506            end,
507            number,
508        });
509        start = end;
510        number += 1;
511    }
512
513    if start < source.len() {
514        let raw = &source[start..];
515        lines.push(Line {
516            content: strip_line_ending(raw),
517            start,
518            end: source.len(),
519            number,
520        });
521    }
522
523    lines
524}
525
/// Removes one trailing line terminator: `\r\n`, `\n`, or a lone `\r`.
fn strip_line_ending(line: &str) -> &str {
    line.strip_suffix("\r\n")
        .or_else(|| line.strip_suffix('\n'))
        .or_else(|| line.strip_suffix('\r'))
        .unwrap_or(line)
}
530
/// A parsed `:::<type>` opener line.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct Opener<'a> {
    /// Number of leading colons (at least 3).
    fence_length: usize,
    /// Trimmed text after the colons; never empty and never all colons.
    block_type: &'a str,
}
536
537fn parse_opener(content: &str) -> Option<Opener<'_>> {
538    let trimmed = trim_horizontal(content);
539    let fence_length = trimmed.bytes().take_while(|byte| *byte == b':').count();
540    if fence_length < 3 {
541        return None;
542    }
543
544    let block_type = trim_horizontal(&trimmed[fence_length..]);
545    if block_type.is_empty() || block_type.bytes().all(|byte| byte == b':') {
546        return None;
547    }
548
549    Some(Opener {
550        fence_length,
551        block_type,
552    })
553}
554
555fn is_closer(content: &str, opener_length: usize) -> bool {
556    let trimmed = trim_horizontal(content);
557    let length = trimmed.bytes().take_while(|byte| *byte == b':').count();
558    length >= opener_length && length == trimmed.len()
559}
560
561fn is_plus_delimiter(content: &str) -> bool {
562    trim_horizontal(content) == "+++"
563}
564
/// Trims spaces and tabs, but no other whitespace, from both ends of `value`.
fn trim_horizontal(value: &str) -> &str {
    value.trim_matches(&[' ', '\t'][..])
}
568
/// Splits `source` into lines that keep their `\n`; an empty source yields no lines.
fn split_inclusive_or_once(source: &str) -> Vec<&str> {
    let mut pieces = Vec::new();
    if !source.is_empty() {
        pieces.extend(source.split_inclusive('\n'));
    }
    pieces
}
576
577pub fn source_hash(source: &str) -> String {
578    let mut hasher = Sha256::new();
579    hasher.update(source.as_bytes());
580    format!("sha256-{}", to_hex(&hasher.finalize()))
581}
582
/// Encodes `bytes` as lowercase hexadecimal, two digits per byte.
fn to_hex(bytes: &[u8]) -> String {
    let mut hex = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        for nibble in [byte >> 4, byte & 0x0f] {
            hex.push(char::from_digit(u32::from(nibble), 16).expect("nibble is below 16"));
        }
    }
    hex
}
592
593fn error(message: impl Into<String>) -> PatchError {
594    PatchError {
595        message: message.into(),
596    }
597}
598
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::path::{Path, PathBuf};

    /// Reads a fixture file addressed relative to the repository root.
    fn fixture(path: &str) -> String {
        fs::read_to_string(repo_path(path)).unwrap()
    }

    /// Deserializes a JSON literal into an [`HmdPatch`], panicking on malformed input.
    fn parse_patch(json: serde_json::Value) -> HmdPatch {
        serde_json::from_value(json).unwrap()
    }

    /// Wraps one operation in a version `0.1` patch with no target hash.
    fn single_operation(operation: serde_json::Value) -> HmdPatch {
        parse_patch(serde_json::json!({
            "patchVersion": "0.1",
            "operations": [operation]
        }))
    }

    #[test]
    fn patches_task_status_todo_to_done() {
        let source = fixture("fixtures/valid/todo-basic.hmd");
        let patch = single_operation(serde_json::json!({
            "op": "set_meta",
            "target": "/blocks/T-parser",
            "field": "status",
            "value": "done"
        }));

        let patched = apply_patch(&source, &patch).unwrap();
        assert!(patched.contains("status = \"done\""));
        assert!(!patched.contains("status = \"todo\""));
    }

    #[test]
    fn appends_decision_choice_block() {
        let source = fixture("fixtures/valid/decision-basic.hmd");
        let patch = single_operation(serde_json::json!({
            "op": "append_block",
            "target": "/blocks/D-runtime",
            "block": {
                "blockType": "choice",
                "meta": {
                    "id": "CH-runtime",
                    "option": "rust",
                    "status": "selected"
                },
                "markdown": "Selected Rust after reviewing the recommendation."
            }
        }));

        let patched = apply_patch(&source, &patch).unwrap();
        assert!(patched.contains("::: choice"));
        assert!(patched.contains("id = \"CH-runtime\""));
        assert!(patched.contains("Selected Rust after reviewing the recommendation."));
    }

    #[test]
    fn missing_target_fails_without_editing() {
        let source = fixture("fixtures/valid/todo-basic.hmd");
        let patch = single_operation(serde_json::json!({
            "op": "replace_body",
            "target": "/blocks/nope",
            "markdown": "No edit."
        }));

        let failure = apply_patch(&source, &patch).unwrap_err();
        assert!(failure.message.contains("was not found"));
    }

    #[test]
    fn duplicate_id_target_fails_safely() {
        let source = fixture("fixtures/invalid/duplicate-id.hmd");
        let patch = single_operation(serde_json::json!({
            "op": "set_meta",
            "target": "/blocks/T-dup",
            "field": "status",
            "value": "done"
        }));

        let failure = apply_patch(&source, &patch).unwrap_err();
        assert!(failure.message.contains("ambiguous"));
    }

    #[test]
    fn stale_hash_fails_before_editing() {
        let source = fixture("fixtures/valid/decision-basic.hmd");
        let patch = parse_patch(serde_json::json!({
            "patchVersion": "0.1",
            "targetHash": "sha256-stale",
            "operations": [{
                "op": "set_meta",
                "target": "/blocks/D-runtime",
                "field": "status",
                "value": "accepted"
            }]
        }));

        let failure = apply_patch(&source, &patch).unwrap_err();
        assert!(failure.message.contains("target hash mismatch"));
        assert!(source.contains("status = \"recommended\""));
    }

    #[test]
    fn duplicate_id_append_fails_safely() {
        let source = fixture("fixtures/valid/decision-basic.hmd");
        let patch = single_operation(serde_json::json!({
            "op": "append_block",
            "target": "/blocks/D-runtime",
            "block": {
                "blockType": "choice",
                "meta": {
                    "id": "rust",
                    "option": "rust",
                    "status": "selected"
                },
                "markdown": "Duplicate."
            }
        }));

        let failure = apply_patch(&source, &patch).unwrap_err();
        assert!(failure.message.contains("duplicate id 'rust'"));
    }

    /// Resolves `path` against the repository root (two levels above this crate).
    fn repo_path(path: impl AsRef<Path>) -> PathBuf {
        Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("../..")
            .join(path)
    }
}