Skip to main content

agent_launch/
platforms.rs

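//! Load per-platform prompt templates from `prompts/<kind>.md`.
//!
//! Each template is a markdown file that starts with a YAML frontmatter block delimited by
//! `---` lines, followed by `## System`, `## User` and `## Anti-examples` sections.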
2
3use std::path::{Path, PathBuf};
4
5use regex::Regex;
6use serde::Deserialize;
7use thiserror::Error;
8
/// A platform for which a prompt template exists.
///
/// Each variant maps to a lowercase identifier (see [`PlatformKind::as_str`]) that is also
/// the stem of its template file, `prompts/<id>.md`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PlatformKind {
    /// Identifier `"hn"`.
    Hn,
    /// Identifier `"reddit"`.
    Reddit,
    /// Identifier `"x"`.
    X,
    /// Identifier `"mastodon"`.
    Mastodon,
    /// Identifier `"linkedin"`.
    Linkedin,
}

impl PlatformKind {
    /// Every variant, in declaration order. Keep in sync when adding a variant.
    pub const ALL: [PlatformKind; 5] = [
        PlatformKind::Hn,
        PlatformKind::Reddit,
        PlatformKind::X,
        PlatformKind::Mastodon,
        PlatformKind::Linkedin,
    ];

    /// Return the lowercase identifier for this platform (e.g. `"hn"`).
    pub fn as_str(&self) -> &'static str {
        match self {
            PlatformKind::Hn => "hn",
            PlatformKind::Reddit => "reddit",
            PlatformKind::X => "x",
            PlatformKind::Mastodon => "mastodon",
            PlatformKind::Linkedin => "linkedin",
        }
    }

    /// Parse an identifier produced by [`PlatformKind::as_str`] back into a variant.
    ///
    /// Matching is exact and case-sensitive; any other input yields `None`.
    pub fn parse(s: &str) -> Option<Self> {
        // Derived from `as_str` rather than a second hand-written table, so the
        // string <-> variant mapping cannot drift between the two directions.
        Self::ALL.iter().copied().find(|kind| kind.as_str() == s)
    }
}
40
41#[derive(Debug, Error)]
42pub enum PlatformError {
43    #[error("unknown platform: {0}")]
44    UnknownKind(String),
45    #[error("platform template not found: {0}")]
46    NotFound(String),
47    #[error("invalid template (missing or malformed frontmatter): {0}")]
48    BadFrontmatter(String),
49    #[error("section \"{0}\" not found in template")]
50    MissingSection(&'static str),
51    #[error("yaml: {0}")]
52    Yaml(#[from] serde_yaml::Error),
53    #[error("io: {0}")]
54    Io(#[from] std::io::Error),
55}
56
57#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
58#[serde(rename_all = "snake_case")]
59pub enum LengthCapField {
60    Body,
61    PerTweet,
62}
63
64#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
65#[serde(rename_all = "snake_case")]
66pub enum OutputFormat {
67    Json,
68    Thread,
69    Text,
70}
71
72#[derive(Debug, Clone, Deserialize)]
73struct Frontmatter {
74    platform: String,
75    length_cap: usize,
76    length_cap_field: LengthCapField,
77    output_format: OutputFormat,
78    #[serde(default)]
79    title_cap: Option<usize>,
80    #[serde(default)]
81    min_tweets: Option<usize>,
82    #[serde(default)]
83    max_tweets: Option<usize>,
84}
85
86#[derive(Debug, Clone)]
87pub struct PlatformTemplate {
88    pub platform: String,
89    pub length_cap: usize,
90    pub length_cap_field: LengthCapField,
91    pub output_format: OutputFormat,
92    pub title_cap: Option<usize>,
93    pub min_tweets: Option<usize>,
94    pub max_tweets: Option<usize>,
95    pub system: String,
96    pub user_template: String,
97    pub anti_examples: String,
98}
99
100pub fn list_platforms() -> Vec<PlatformKind> {
101    vec![
102        PlatformKind::Hn,
103        PlatformKind::Reddit,
104        PlatformKind::X,
105        PlatformKind::Mastodon,
106        PlatformKind::Linkedin,
107    ]
108}
109
110/// Locate the prompt file for a given kind. Walks up from CARGO_MANIFEST_DIR to find `prompts/`.
111fn find_prompt(kind: PlatformKind) -> Result<PathBuf, PlatformError> {
112    let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
113    let candidate = manifest_dir
114        .join("prompts")
115        .join(format!("{}.md", kind.as_str()));
116    if candidate.is_file() {
117        return Ok(candidate);
118    }
119    // Fall back: walk up parents.
120    let mut p: Option<&Path> = Some(&manifest_dir);
121    while let Some(dir) = p {
122        let c = dir.join("prompts").join(format!("{}.md", kind.as_str()));
123        if c.is_file() {
124            return Ok(c);
125        }
126        p = dir.parent();
127    }
128    Err(PlatformError::NotFound(kind.as_str().into()))
129}
130
131pub fn load_platform_template(kind: &str) -> Result<PlatformTemplate, PlatformError> {
132    let k = PlatformKind::parse(kind).ok_or_else(|| PlatformError::UnknownKind(kind.into()))?;
133    let path = find_prompt(k)?;
134    let raw = std::fs::read_to_string(&path)?;
135    let re = Regex::new(r"(?s)^---\n(.*?)\n---\n(.*)$").unwrap();
136    let caps = re
137        .captures(&raw)
138        .ok_or_else(|| PlatformError::BadFrontmatter(path.display().to_string()))?;
139    let fm_yaml = caps.get(1).unwrap().as_str();
140    let body = caps.get(2).unwrap().as_str();
141    let fm: Frontmatter = serde_yaml::from_str(fm_yaml)?;
142
143    Ok(PlatformTemplate {
144        platform: fm.platform,
145        length_cap: fm.length_cap,
146        length_cap_field: fm.length_cap_field,
147        output_format: fm.output_format,
148        title_cap: fm.title_cap,
149        min_tweets: fm.min_tweets,
150        max_tweets: fm.max_tweets,
151        system: extract_section(body, "System")?,
152        user_template: extract_section(body, "User")?,
153        anti_examples: extract_section(body, "Anti-examples")?,
154    })
155}
156
157fn extract_section(body: &str, heading: &'static str) -> Result<String, PlatformError> {
158    // Find a line `## <heading>` and capture until the next `## ` or EOF.
159    // `regex` crate has no lookahead, so we do this manually.
160    let needle = format!("## {heading}");
161    let lines: Vec<&str> = body.lines().collect();
162    let start = lines
163        .iter()
164        .position(|l| *l == needle)
165        .ok_or(PlatformError::MissingSection(heading))?;
166    let mut end = lines.len();
167    for (i, line) in lines.iter().enumerate().skip(start + 1) {
168        if line.starts_with("## ") {
169            end = i;
170            break;
171        }
172    }
173    let section = lines[start + 1..end].join("\n");
174    Ok(section.trim().to_string())
175}