1use std::path::{Path, PathBuf};
4
5use regex::Regex;
6use serde::Deserialize;
7use thiserror::Error;
8
/// The set of platforms for which a prompt template can be loaded.
///
/// Every variant has a short lowercase string identifier, available through
/// `PlatformKind::as_str` and accepted by `PlatformKind::parse`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PlatformKind {
    /// Platform whose string identifier is `"hn"`.
    Hn,
    /// Platform whose string identifier is `"reddit"`.
    Reddit,
    /// Platform whose string identifier is `"x"`.
    X,
    /// Platform whose string identifier is `"mastodon"`.
    Mastodon,
    /// Platform whose string identifier is `"linkedin"`.
    Linkedin,
}
17
18impl PlatformKind {
19 pub fn as_str(&self) -> &'static str {
20 match self {
21 PlatformKind::Hn => "hn",
22 PlatformKind::Reddit => "reddit",
23 PlatformKind::X => "x",
24 PlatformKind::Mastodon => "mastodon",
25 PlatformKind::Linkedin => "linkedin",
26 }
27 }
28
29 pub fn parse(s: &str) -> Option<Self> {
30 match s {
31 "hn" => Some(PlatformKind::Hn),
32 "reddit" => Some(PlatformKind::Reddit),
33 "x" => Some(PlatformKind::X),
34 "mastodon" => Some(PlatformKind::Mastodon),
35 "linkedin" => Some(PlatformKind::Linkedin),
36 _ => None,
37 }
38 }
39}
40
41#[derive(Debug, Error)]
42pub enum PlatformError {
43 #[error("unknown platform: {0}")]
44 UnknownKind(String),
45 #[error("platform template not found: {0}")]
46 NotFound(String),
47 #[error("invalid template (missing or malformed frontmatter): {0}")]
48 BadFrontmatter(String),
49 #[error("section \"{0}\" not found in template")]
50 MissingSection(&'static str),
51 #[error("yaml: {0}")]
52 Yaml(#[from] serde_yaml::Error),
53 #[error("io: {0}")]
54 Io(#[from] std::io::Error),
55}
56
57#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
58#[serde(rename_all = "snake_case")]
59pub enum LengthCapField {
60 Body,
61 PerTweet,
62}
63
64#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
65#[serde(rename_all = "snake_case")]
66pub enum OutputFormat {
67 Json,
68 Thread,
69 Text,
70}
71
72#[derive(Debug, Clone, Deserialize)]
73struct Frontmatter {
74 platform: String,
75 length_cap: usize,
76 length_cap_field: LengthCapField,
77 output_format: OutputFormat,
78 #[serde(default)]
79 title_cap: Option<usize>,
80 #[serde(default)]
81 min_tweets: Option<usize>,
82 #[serde(default)]
83 max_tweets: Option<usize>,
84}
85
86#[derive(Debug, Clone)]
87pub struct PlatformTemplate {
88 pub platform: String,
89 pub length_cap: usize,
90 pub length_cap_field: LengthCapField,
91 pub output_format: OutputFormat,
92 pub title_cap: Option<usize>,
93 pub min_tweets: Option<usize>,
94 pub max_tweets: Option<usize>,
95 pub system: String,
96 pub user_template: String,
97 pub anti_examples: String,
98}
99
100pub fn list_platforms() -> Vec<PlatformKind> {
101 vec![
102 PlatformKind::Hn,
103 PlatformKind::Reddit,
104 PlatformKind::X,
105 PlatformKind::Mastodon,
106 PlatformKind::Linkedin,
107 ]
108}
109
110fn find_prompt(kind: PlatformKind) -> Result<PathBuf, PlatformError> {
112 let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
113 let candidate = manifest_dir
114 .join("prompts")
115 .join(format!("{}.md", kind.as_str()));
116 if candidate.is_file() {
117 return Ok(candidate);
118 }
119 let mut p: Option<&Path> = Some(&manifest_dir);
121 while let Some(dir) = p {
122 let c = dir.join("prompts").join(format!("{}.md", kind.as_str()));
123 if c.is_file() {
124 return Ok(c);
125 }
126 p = dir.parent();
127 }
128 Err(PlatformError::NotFound(kind.as_str().into()))
129}
130
131pub fn load_platform_template(kind: &str) -> Result<PlatformTemplate, PlatformError> {
132 let k = PlatformKind::parse(kind).ok_or_else(|| PlatformError::UnknownKind(kind.into()))?;
133 let path = find_prompt(k)?;
134 let raw = std::fs::read_to_string(&path)?;
135 let re = Regex::new(r"(?s)^---\n(.*?)\n---\n(.*)$").unwrap();
136 let caps = re
137 .captures(&raw)
138 .ok_or_else(|| PlatformError::BadFrontmatter(path.display().to_string()))?;
139 let fm_yaml = caps.get(1).unwrap().as_str();
140 let body = caps.get(2).unwrap().as_str();
141 let fm: Frontmatter = serde_yaml::from_str(fm_yaml)?;
142
143 Ok(PlatformTemplate {
144 platform: fm.platform,
145 length_cap: fm.length_cap,
146 length_cap_field: fm.length_cap_field,
147 output_format: fm.output_format,
148 title_cap: fm.title_cap,
149 min_tweets: fm.min_tweets,
150 max_tweets: fm.max_tweets,
151 system: extract_section(body, "System")?,
152 user_template: extract_section(body, "User")?,
153 anti_examples: extract_section(body, "Anti-examples")?,
154 })
155}
156
157fn extract_section(body: &str, heading: &'static str) -> Result<String, PlatformError> {
158 let needle = format!("## {heading}");
161 let lines: Vec<&str> = body.lines().collect();
162 let start = lines
163 .iter()
164 .position(|l| *l == needle)
165 .ok_or(PlatformError::MissingSection(heading))?;
166 let mut end = lines.len();
167 for (i, line) in lines.iter().enumerate().skip(start + 1) {
168 if line.starts_with("## ") {
169 end = i;
170 break;
171 }
172 }
173 let section = lines[start + 1..end].join("\n");
174 Ok(section.trim().to_string())
175}