Skip to main content

hmd_patch/
lib.rs

1use hmd_core::TomlValueObject;
2use serde::Deserialize;
3use serde_json::Value;
4use sha2::{Digest, Sha256};
5use toml_edit::DocumentMut;
6
7#[derive(Debug, Clone, Deserialize)]
8#[serde(rename_all = "camelCase")]
9pub struct HmdPatch {
10    pub patch_version: String,
11    pub target_hash: Option<String>,
12    pub operations: Vec<PatchOperation>,
13    pub author: Option<String>,
14    pub created: Option<String>,
15}
16
17#[derive(Debug, Clone, Deserialize)]
18#[serde(tag = "op", rename_all = "snake_case")]
19pub enum PatchOperation {
20    SetMeta {
21        target: String,
22        field: String,
23        value: Value,
24    },
25    ReplaceBody {
26        target: String,
27        markdown: String,
28    },
29    AppendBlock {
30        target: String,
31        block: AppendBlock,
32    },
33}
34
35#[derive(Debug, Clone, Deserialize)]
36#[serde(rename_all = "camelCase")]
37pub struct AppendBlock {
38    pub block_type: String,
39    #[serde(default)]
40    pub meta: TomlValueObject,
41    #[serde(default)]
42    pub markdown: String,
43}
44
/// Error returned when a patch cannot be applied, carrying a human-readable explanation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PatchError {
    /// Description of what went wrong, e.g. a missing or ambiguous target.
    pub message: String,
}

impl std::fmt::Display for PatchError {
    /// Writes the message verbatim; width/fill flags on the outer formatter are not applied.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(formatter, "{}", self.message)
    }
}

/// Lets `PatchError` flow into `Box<dyn std::error::Error>` and similar error plumbing.
impl std::error::Error for PatchError {}
57
58pub fn apply_patch(source: &str, patch: &HmdPatch) -> Result<String, PatchError> {
59    if patch.patch_version != "0.1" {
60        return Err(error(format!(
61            "unsupported patch version '{}'",
62            patch.patch_version
63        )));
64    }
65
66    if let Some(target_hash) = &patch.target_hash {
67        let actual = source_hash(source);
68        if target_hash != &actual {
69            return Err(error(format!(
70                "target hash mismatch: expected {target_hash}, got {actual}"
71            )));
72        }
73    }
74
75    validate_operations(source, &patch.operations)?;
76
77    let mut current = source.to_string();
78    for operation in &patch.operations {
79        current = apply_operation(&current, operation)?;
80    }
81    Ok(current)
82}
83
84fn validate_operations(source: &str, operations: &[PatchOperation]) -> Result<(), PatchError> {
85    let document = hmd_parse::parse_document(source);
86
87    for operation in operations {
88        match operation {
89            PatchOperation::SetMeta { target, .. } | PatchOperation::ReplaceBody { target, .. } => {
90                validate_block_target(&document, target)?;
91            }
92            PatchOperation::AppendBlock { target, .. } => {
93                if target != "/document" {
94                    validate_block_target(&document, target)?;
95                }
96            }
97        }
98    }
99
100    Ok(())
101}
102
103fn validate_block_target(document: &hmd_core::HmdDocument, target: &str) -> Result<(), PatchError> {
104    let Some(id) = target.strip_prefix("/blocks/") else {
105        return Err(error(format!(
106            "patch target '{target}' must use /blocks/<id>"
107        )));
108    };
109
110    if document
111        .references
112        .duplicates
113        .iter()
114        .any(|record| record.id == id)
115    {
116        return Err(error(format!("patch target '{target}' is ambiguous")));
117    }
118
119    if !document.references.ids.contains_key(id) {
120        return Err(error(format!("patch target '{target}' was not found")));
121    }
122
123    Ok(())
124}
125
126fn apply_operation(source: &str, operation: &PatchOperation) -> Result<String, PatchError> {
127    match operation {
128        PatchOperation::SetMeta {
129            target,
130            field,
131            value,
132        } => apply_set_meta(source, target, field, value),
133        PatchOperation::ReplaceBody { target, markdown } => {
134            apply_replace_body(source, target, markdown)
135        }
136        PatchOperation::AppendBlock { target, block } => apply_append_block(source, target, block),
137    }
138}
139
140fn apply_set_meta(
141    source: &str,
142    target: &str,
143    field: &str,
144    value: &Value,
145) -> Result<String, PatchError> {
146    let span = find_target_span(source, target)?;
147    let value_source = toml_literal(value)?;
148    let mut output = String::new();
149
150    if let Some(meta) = &span.meta {
151        let updated = set_meta_field(&source[meta.content_start..meta.content_end], field, value)?;
152        output.push_str(&source[..meta.content_start]);
153        output.push_str(&updated);
154        output.push_str(&source[meta.content_end..]);
155    } else {
156        output.push_str(&source[..span.opener_end]);
157        output.push_str("+++\n");
158        output.push_str(field);
159        output.push_str(" = ");
160        output.push_str(&value_source);
161        output.push('\n');
162        output.push_str("+++\n");
163        output.push_str(&source[span.opener_end..]);
164    }
165
166    Ok(output)
167}
168
169fn apply_replace_body(source: &str, target: &str, markdown: &str) -> Result<String, PatchError> {
170    let span = find_target_span(source, target)?;
171    let mut output = String::new();
172    output.push_str(&source[..span.body_start]);
173    output.push_str(&body_replacement(markdown));
174    output.push_str(&source[span.close_start..]);
175    Ok(output)
176}
177
178fn apply_append_block(
179    source: &str,
180    target: &str,
181    block: &AppendBlock,
182) -> Result<String, PatchError> {
183    let block_source = block_to_source(block);
184    let (insert_at, insertion) = if target == "/document" {
185        let prefix = if source.ends_with('\n') { "\n" } else { "\n\n" };
186        (source.len(), format!("{prefix}{block_source}"))
187    } else {
188        let span = find_target_span(source, target)?;
189        (span.close_start, format!("\n{block_source}\n"))
190    };
191
192    let mut output = String::new();
193    output.push_str(&source[..insert_at]);
194    output.push_str(&insertion);
195    output.push_str(&source[insert_at..]);
196    Ok(output)
197}
198
199fn find_target_span(source: &str, target: &str) -> Result<BlockSpan, PatchError> {
200    let id = target
201        .strip_prefix("/blocks/")
202        .ok_or_else(|| error(format!("patch target '{target}' must use /blocks/<id>")))?;
203    let spans = scan_block_spans(source)?;
204    spans
205        .into_iter()
206        .find(|span| span.id.as_deref() == Some(id))
207        .ok_or_else(|| error(format!("patch target '{target}' was not found")))
208}
209
210fn set_meta_field(source: &str, field: &str, value: &Value) -> Result<String, PatchError> {
211    let value_source = toml_literal(value)?;
212    let replacement = format!("{field} = {value_source}\n");
213    let mut output = String::new();
214    let mut replaced = false;
215
216    for line in split_inclusive_or_once(source) {
217        if is_toml_key_line(line, field) {
218            output.push_str(&replacement);
219            replaced = true;
220        } else {
221            output.push_str(line);
222        }
223    }
224
225    if !replaced {
226        if !output.is_empty() && !output.ends_with('\n') {
227            output.push('\n');
228        }
229        output.push_str(&replacement);
230    }
231
232    output.parse::<DocumentMut>().map_err(|parse_error| {
233        error(format!(
234            "set_meta produced invalid TOML metadata for field '{field}': {parse_error}"
235        ))
236    })?;
237
238    Ok(output)
239}
240
/// Reports whether `line` is a `field = ...` assignment.
///
/// Leading spaces/tabs are ignored, `field` must appear verbatim, and the next non-blank
/// character after it must be `=`. This rules out longer keys that merely start with `field`.
fn is_toml_key_line(line: &str, field: &str) -> bool {
    let is_blank = |ch: char| ch == ' ' || ch == '\t';
    line.trim_start_matches(is_blank)
        .strip_prefix(field)
        .map(|after_key| after_key.trim_start_matches(is_blank))
        .is_some_and(|after_key| after_key.starts_with('='))
}
248
/// Normalises replacement body text.
///
/// Surrounding newlines are stripped. A non-empty body becomes `\n<body>\n\n`, i.e. one
/// blank line before and after it. An empty body collapses to a single `\n`.
fn body_replacement(markdown: &str) -> String {
    match markdown.trim_matches('\n') {
        "" => String::from("\n"),
        content => format!("\n{content}\n\n"),
    }
}
257
258fn block_to_source(block: &AppendBlock) -> String {
259    let mut output = String::new();
260    output.push_str(&format!("::: {}\n", block.block_type));
261    output.push_str("+++\n");
262    for (key, value) in &block.meta {
263        output.push_str(key);
264        output.push_str(" = ");
265        output.push_str(&toml_literal(value).expect("JSON value serializes to TOML literal"));
266        output.push('\n');
267    }
268    output.push_str("+++\n");
269    if !block.markdown.trim().is_empty() {
270        output.push('\n');
271        output.push_str(block.markdown.trim_matches('\n'));
272        output.push_str("\n\n");
273    }
274    output.push_str(":::\n");
275    output
276}
277
278fn toml_literal(value: &Value) -> Result<String, PatchError> {
279    match value {
280        Value::Null => Err(error("null is not a valid TOML metadata value")),
281        Value::Bool(value) => Ok(value.to_string()),
282        Value::Number(value) => Ok(value.to_string()),
283        Value::String(value) => serde_json::to_string(value)
284            .map_err(|error| self::error(format!("failed to encode TOML string: {error}"))),
285        Value::Array(values) => {
286            let values = values
287                .iter()
288                .map(toml_literal)
289                .collect::<Result<Vec<_>, _>>()?;
290            Ok(format!("[{}]", values.join(", ")))
291        }
292        Value::Object(values) => {
293            let values = values
294                .iter()
295                .map(|(key, value)| Ok(format!("{key} = {}", toml_literal(value)?)))
296                .collect::<Result<Vec<_>, PatchError>>()?;
297            Ok(format!("{{ {} }}", values.join(", ")))
298        }
299    }
300}
301
/// Byte offsets of one fenced block found by `scan_block_spans`.
#[derive(Debug, Clone)]
struct BlockSpan {
    /// String `id` read from the block's metadata, if any.
    id: Option<String>,
    /// Offset just past the opener line (including its line ending).
    opener_end: usize,
    /// Offset where the body begins: after the metadata closer, or after the opener if there
    /// is no metadata.
    body_start: usize,
    /// Offset of the first byte of the closing fence line.
    close_start: usize,
    /// Location of the metadata text between the `+++` lines, if the block has a section.
    meta: Option<MetaSpan>,
}

/// Offsets of the TOML text between a block's two `+++` delimiter lines.
#[derive(Debug, Clone)]
struct MetaSpan {
    /// Offset of the first line after the opening `+++`.
    content_start: usize,
    /// Offset of the closing `+++` line (exclusive end of the TOML text).
    content_end: usize,
}

/// A block whose opener has been seen but whose closing fence has not, yet.
#[derive(Debug, Clone)]
struct OpenBlock {
    /// String `id` read from the block's metadata, if any.
    id: Option<String>,
    /// Number of `:` characters in the opener; a closer needs at least this many.
    fence_length: usize,
    /// Offset just past the opener line.
    opener_end: usize,
    /// Offset where the body begins.
    body_start: usize,
    /// Location of the metadata text, if present.
    meta: Option<MetaSpan>,
}
325
326fn scan_block_spans(source: &str) -> Result<Vec<BlockSpan>, PatchError> {
327    let lines = collect_lines(source);
328    let mut spans = Vec::new();
329    let mut stack: Vec<OpenBlock> = Vec::new();
330    let mut index = body_start_index(&lines);
331
332    while index < lines.len() {
333        let line = &lines[index];
334
335        if let Some(open) = stack.last() {
336            if is_closer(line.content, open.fence_length) {
337                let open = stack.pop().expect("stack was not empty");
338                spans.push(BlockSpan {
339                    id: open.id,
340                    opener_end: open.opener_end,
341                    body_start: open.body_start,
342                    close_start: line.start,
343                    meta: open.meta,
344                });
345                index += 1;
346                continue;
347            }
348        }
349
350        if let Some(opener) = parse_opener(line.content) {
351            let mut open = OpenBlock {
352                id: None,
353                fence_length: opener.fence_length,
354                opener_end: line.end,
355                body_start: line.end,
356                meta: None,
357            };
358            index += 1;
359
360            if lines
361                .get(index)
362                .is_some_and(|line| is_plus_delimiter(line.content))
363            {
364                let meta_open = &lines[index];
365                let Some(meta_close_index) =
366                    lines
367                        .iter()
368                        .enumerate()
369                        .skip(index + 1)
370                        .find_map(|(candidate, line)| {
371                            is_plus_delimiter(line.content).then_some(candidate)
372                        })
373                else {
374                    return Err(error(format!(
375                        "unterminated block metadata at line {}",
376                        meta_open.number
377                    )));
378                };
379                let content_start = lines
380                    .get(index + 1)
381                    .map(|line| line.start)
382                    .unwrap_or(meta_open.end);
383                let content_end = lines[meta_close_index].start;
384                let meta_source = &source[content_start..content_end];
385                open.id = parse_meta_id(meta_source);
386                open.body_start = lines[meta_close_index].end;
387                open.meta = Some(MetaSpan {
388                    content_start,
389                    content_end,
390                });
391                index = meta_close_index + 1;
392            }
393
394            stack.push(open);
395            continue;
396        }
397
398        index += 1;
399    }
400
401    Ok(spans)
402}
403
404fn body_start_index(lines: &[Line<'_>]) -> usize {
405    if !lines
406        .first()
407        .is_some_and(|line| is_plus_delimiter(line.content))
408    {
409        return 0;
410    }
411
412    lines
413        .iter()
414        .enumerate()
415        .skip(1)
416        .find_map(|(index, line)| is_plus_delimiter(line.content).then_some(index + 1))
417        .unwrap_or(lines.len())
418}
419
420fn parse_meta_id(source: &str) -> Option<String> {
421    source.parse::<DocumentMut>().ok().and_then(|document| {
422        document
423            .get("id")
424            .and_then(|value| value.as_str())
425            .map(str::to_string)
426    })
427}
428
/// One physical line of a source document, with byte offsets into that source.
#[derive(Debug, Clone, Copy)]
struct Line<'a> {
    /// Line text without its trailing `\n` / `\r\n`.
    content: &'a str,
    /// Byte offset of the line's first character.
    start: usize,
    /// Byte offset just past the line, including its line terminator if it has one.
    end: usize,
    /// 1-based line number, used in error messages.
    number: usize,
}

/// Splits `source` into [`Line`]s that together cover every byte of the input.
///
/// `str::split_inclusive` already yields a final unterminated line, so no separate tail
/// handling is needed. An empty `source` produces no lines.
fn collect_lines(source: &str) -> Vec<Line<'_>> {
    let mut start = 0;
    source
        .split_inclusive('\n')
        .zip(1..)
        .map(|(raw, number)| {
            let end = start + raw.len();
            let line = Line {
                content: strip_line_ending(raw),
                start,
                end,
                number,
            };
            start = end;
            line
        })
        .collect()
}

/// Removes one trailing `\n` and then one trailing `\r`, dropping LF or CRLF endings.
fn strip_line_ending(line: &str) -> &str {
    let without_lf = line.strip_suffix('\n').unwrap_or(line);
    without_lf.strip_suffix('\r').unwrap_or(without_lf)
}
471
/// A parsed `:::` block opener line.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Opener<'a> {
    /// Number of leading `:` characters (at least 3).
    fence_length: usize,
    /// Text after the fence, trimmed of spaces and tabs.
    block_type: &'a str,
}

/// Parses an opener such as `::: note`.
///
/// Returns `None` for lines with fewer than three leading colons, and for lines whose text
/// after the fence is empty or only colons (i.e. closers).
fn parse_opener(content: &str) -> Option<Opener<'_>> {
    let fence_and_type = trim_horizontal(content);
    let after_fence = fence_and_type.trim_start_matches(':');
    let fence_length = fence_and_type.len() - after_fence.len();
    if fence_length < 3 {
        return None;
    }

    let block_type = trim_horizontal(after_fence);
    // `all` is vacuously true for an empty string, so this also rejects a missing type.
    let only_colons = block_type.bytes().all(|byte| byte == b':');
    (!only_colons).then_some(Opener {
        fence_length,
        block_type,
    })
}

/// Reports whether `content` is a closing fence for an opener of `opener_length` colons:
/// after trimming, it must consist only of colons and have at least that many.
fn is_closer(content: &str, opener_length: usize) -> bool {
    let fence = trim_horizontal(content);
    fence.len() >= opener_length && fence.bytes().all(|byte| byte == b':')
}

/// Reports whether `content` is a `+++` delimiter line, ignoring surrounding spaces/tabs.
fn is_plus_delimiter(content: &str) -> bool {
    matches!(trim_horizontal(content), "+++")
}

/// Trims spaces and tabs (but not newlines) from both ends of `value`.
fn trim_horizontal(value: &str) -> &str {
    value.trim_matches([' ', '\t'])
}
509
/// Splits `source` into lines that keep their trailing `\n`.
///
/// An empty `source` yields no lines; `str::split_inclusive` already behaves that way, so
/// no special case is needed.
fn split_inclusive_or_once(source: &str) -> Vec<&str> {
    source.split_inclusive('\n').collect()
}
517
518fn source_hash(source: &str) -> String {
519    let mut hasher = Sha256::new();
520    hasher.update(source.as_bytes());
521    format!("sha256-{}", to_hex(&hasher.finalize()))
522}
523
/// Encodes `bytes` as lowercase hexadecimal, two digits per byte.
fn to_hex(bytes: &[u8]) -> String {
    const DIGITS: [char; 16] = [
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
    ];
    bytes
        .iter()
        .flat_map(|&byte| {
            [
                DIGITS[usize::from(byte >> 4)],
                DIGITS[usize::from(byte & 0x0f)],
            ]
        })
        .collect()
}
533
534fn error(message: impl Into<String>) -> PatchError {
535    PatchError {
536        message: message.into(),
537    }
538}
539
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::path::{Path, PathBuf};

    /// Loads a fixture file addressed relative to the repository root.
    fn read_fixture(relative: &str) -> String {
        fs::read_to_string(repo_path(relative)).unwrap()
    }

    /// Builds an [`HmdPatch`] from inline JSON, panicking if it does not deserialize.
    fn patch_from(json: serde_json::Value) -> HmdPatch {
        serde_json::from_value(json).unwrap()
    }

    #[test]
    fn patches_task_status_todo_to_done() {
        let source = read_fixture("fixtures/valid/todo-basic.hmd");
        let patch = patch_from(serde_json::json!({
            "patchVersion": "0.1",
            "operations": [{
                "op": "set_meta",
                "target": "/blocks/T-parser",
                "field": "status",
                "value": "done"
            }]
        }));

        let patched = apply_patch(&source, &patch).unwrap();

        assert!(patched.contains("status = \"done\""));
        assert!(!patched.contains("status = \"todo\""));
    }

    #[test]
    fn appends_decision_choice_block() {
        let source = read_fixture("fixtures/valid/decision-basic.hmd");
        let patch = patch_from(serde_json::json!({
            "patchVersion": "0.1",
            "operations": [{
                "op": "append_block",
                "target": "/blocks/D-runtime",
                "block": {
                    "blockType": "choice",
                    "meta": {
                        "id": "CH-runtime",
                        "option": "rust",
                        "status": "selected"
                    },
                    "markdown": "Selected Rust after reviewing the recommendation."
                }
            }]
        }));

        let patched = apply_patch(&source, &patch).unwrap();

        assert!(patched.contains("::: choice"));
        assert!(patched.contains("id = \"CH-runtime\""));
        assert!(patched.contains("Selected Rust after reviewing the recommendation."));
    }

    #[test]
    fn missing_target_fails_without_editing() {
        let source = read_fixture("fixtures/valid/todo-basic.hmd");
        let patch = patch_from(serde_json::json!({
            "patchVersion": "0.1",
            "operations": [{
                "op": "replace_body",
                "target": "/blocks/nope",
                "markdown": "No edit."
            }]
        }));

        let failure = apply_patch(&source, &patch).unwrap_err();

        assert!(failure.message.contains("was not found"));
    }

    #[test]
    fn duplicate_id_target_fails_safely() {
        let source = read_fixture("fixtures/invalid/duplicate-id.hmd");
        let patch = patch_from(serde_json::json!({
            "patchVersion": "0.1",
            "operations": [{
                "op": "set_meta",
                "target": "/blocks/T-dup",
                "field": "status",
                "value": "done"
            }]
        }));

        let failure = apply_patch(&source, &patch).unwrap_err();

        assert!(failure.message.contains("ambiguous"));
    }

    /// Resolves `path` against the repository root, two levels above this crate's manifest.
    fn repo_path(path: impl AsRef<Path>) -> PathBuf {
        Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("../..")
            .join(path)
    }
}