1use async_trait::async_trait;
4use serde::Serialize;
5use serde_json::{json, Map, Value};
6
7use crate::config::{Platform, Project};
8use crate::context::GatheredContext;
9use crate::platforms::{
10 load_platform_template, LengthCapField, OutputFormat, PlatformError, PlatformTemplate,
11};
12
/// Maximum number of "rewrite shorter" retries `draft_one` makes after its first attempt,
/// so at most `MAX_RETRIES + 1` requests are sent per draft.
pub const MAX_RETRIES: usize = 2;
/// Model identifier sent as the `model` field of every request built by `draft_one`.
pub const MODEL: &str = "claude-opus-4-7";
15
/// Abstraction over the API call used by `draft_one`.
///
/// `draft_one` only depends on this trait (it takes `&dyn AnthropicClient`), so any
/// implementation that is `Send + Sync` can be supplied.
#[async_trait]
pub trait AnthropicClient: Send + Sync {
    /// Sends one request and returns the raw JSON response.
    ///
    /// `params` is the JSON request object; `draft_one` fills in `model`, `max_tokens`,
    /// `temperature`, `system` and `messages`. The response is read by `extract_text`, which
    /// looks for a top-level `content` array of blocks with `"type": "text"` and a `text` field.
    async fn create(&self, params: Value) -> Result<Value, DraftError>;
}
25
/// Errors that can abort drafting.
#[derive(Debug, thiserror::Error)]
pub enum DraftError {
    /// The platform template could not be loaded; converted automatically from
    /// `PlatformError` via `?`.
    #[error("platform template error: {0}")]
    Template(#[from] PlatformError),
    /// An API failure described by a message.
    // NOTE(review): not constructed in this section — presumably produced by
    // `AnthropicClient` implementations; confirm against them.
    #[error("anthropic API error: {0}")]
    Api(String),
}
33
/// Outcome of drafting one platform post, as returned by `draft_one`.
#[derive(Debug, Clone, Serialize)]
pub struct DraftResult {
    /// Platform kind string the draft was written for (the value of `platform.kind()`).
    pub platform: String,
    /// Draft text. For thread output the tweets are joined with a `---tweet---` separator line.
    pub body: String,
    /// Length of `body` counted in `char`s (Unicode scalar values), not bytes.
    pub length: usize,
    /// Length cap copied from the platform template. Whether it limits the whole body or each
    /// tweet depends on the template's `length_cap_field`.
    pub length_cap: usize,
    /// `true` when the draft passed the length check; `false` when every attempt was over the
    /// limit and the last attempt was returned anyway.
    pub capped: bool,
    /// Zero-based index of the attempt that produced this draft (`0` means no retry was needed).
    pub retries: usize,
    /// Title parsed from JSON-format output; left out of the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub title: Option<String>,
    /// Number of tweets for thread-format output; left out of the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tweet_count: Option<usize>,
}
47
/// Model reply after `parse_output` has interpreted it for a given output format.
#[derive(Debug, Clone)]
struct Parsed {
    /// Draft text (for threads: the tweets joined with a `---tweet---` separator line).
    body: String,
    /// Title; only set by the JSON branch of `parse_output`.
    title: Option<String>,
    /// Tweet count; only set by the thread branch of `parse_output`.
    tweet_count: Option<usize>,
}
54
55pub async fn draft_one(
56 platform: &Platform,
57 project: &Project,
58 context: &GatheredContext,
59 repo: &str,
60 anthropic: &dyn AnthropicClient,
61) -> Result<DraftResult, DraftError> {
62 let tpl = load_platform_template(platform.kind())?;
63 let system_prompt = format!("{}\n\n## Anti-examples\n{}", tpl.system, tpl.anti_examples);
64 let user_prompt = render_user_prompt(&tpl.user_template, platform, project, context, repo);
65
66 let mut last: Option<Parsed> = None;
67 let mut retries = 0usize;
68
69 for attempt in 0..=MAX_RETRIES {
70 let final_user = if attempt == 0 {
71 user_prompt.clone()
72 } else {
73 let cap_clause = match tpl.length_cap_field {
74 LengthCapField::PerTweet => {
75 format!("each tweet must be ≤ {} chars", tpl.length_cap)
76 }
77 LengthCapField::Body => {
78 format!("the body must be ≤ {} chars", tpl.length_cap)
79 }
80 };
81 let title_clause = match tpl.title_cap {
82 Some(cap) => format!(" Title ≤ {cap} chars."),
83 None => String::new(),
84 };
85 format!(
86 "{user_prompt}\n\nThe previous attempt exceeded the length limit. Rewrite shorter; {cap_clause}.{title_clause}"
87 )
88 };
89
90 let mut params = Map::new();
91 params.insert("model".into(), json!(MODEL));
92 params.insert("max_tokens".into(), json!(4096));
93 params.insert("temperature".into(), json!(0));
94 params.insert("system".into(), json!(system_prompt));
95 params.insert(
96 "messages".into(),
97 json!([{"role": "user", "content": final_user}]),
98 );
99
100 let resp = anthropic.create(Value::Object(params)).await?;
101 let text = extract_text(&resp).trim().to_string();
102 let parsed = parse_output(&tpl.output_format, &text);
103 let ok = length_ok(&parsed, &tpl);
104 retries = attempt;
105 last = Some(parsed.clone());
106
107 if ok {
108 return Ok(make_result(platform.kind(), parsed, &tpl, true, retries));
109 }
110 }
111 let last = last.expect("loop runs at least once");
112 Ok(make_result(platform.kind(), last, &tpl, false, retries))
113}
114
115fn make_result(
116 kind: &str,
117 parsed: Parsed,
118 tpl: &PlatformTemplate,
119 capped: bool,
120 retries: usize,
121) -> DraftResult {
122 let length = parsed.body.chars().count();
123 DraftResult {
124 platform: kind.to_string(),
125 body: parsed.body,
126 length,
127 length_cap: tpl.length_cap,
128 capped,
129 retries,
130 title: parsed.title,
131 tweet_count: parsed.tweet_count,
132 }
133}
134
135fn extract_text(resp: &Value) -> String {
136 let Some(content) = resp.get("content").and_then(|c| c.as_array()) else {
137 return String::new();
138 };
139 let mut out = String::new();
140 for block in content {
141 if block.get("type").and_then(|v| v.as_str()) == Some("text") {
142 if let Some(t) = block.get("text").and_then(|v| v.as_str()) {
143 out.push_str(t);
144 }
145 }
146 }
147 out
148}
149
150fn parse_output(fmt: &OutputFormat, text: &str) -> Parsed {
151 match fmt {
152 OutputFormat::Json => match serde_json::from_str::<Value>(text) {
153 Ok(obj) => Parsed {
154 body: obj
155 .get("body")
156 .and_then(|v| v.as_str())
157 .unwrap_or("")
158 .to_string(),
159 title: Some(
160 obj.get("title")
161 .and_then(|v| v.as_str())
162 .unwrap_or("")
163 .to_string(),
164 ),
165 tweet_count: None,
166 },
167 Err(_) => Parsed {
168 body: text.to_string(),
169 title: None,
170 tweet_count: None,
171 },
172 },
173 OutputFormat::Thread => {
174 let tweets: Vec<String> = text
175 .split("---tweet---")
176 .map(|s| s.trim().to_string())
177 .filter(|s| !s.is_empty())
178 .collect();
179 let count = tweets.len();
180 Parsed {
181 body: tweets.join("\n---tweet---\n"),
182 title: None,
183 tweet_count: Some(count),
184 }
185 }
186 OutputFormat::Text => Parsed {
187 body: text.to_string(),
188 title: None,
189 tweet_count: None,
190 },
191 }
192}
193
194fn length_ok(parsed: &Parsed, tpl: &PlatformTemplate) -> bool {
195 if let (Some(cap), Some(title)) = (tpl.title_cap, parsed.title.as_ref()) {
196 if title.chars().count() > cap {
197 return false;
198 }
199 }
200 match tpl.length_cap_field {
201 LengthCapField::PerTweet => parsed
202 .body
203 .split("---tweet---")
204 .map(|s| s.trim())
205 .all(|t| t.chars().count() <= tpl.length_cap),
206 LengthCapField::Body => parsed.body.chars().count() <= tpl.length_cap,
207 }
208}
209
210fn render_user_prompt(
211 template: &str,
212 platform: &Platform,
213 project: &Project,
214 context: &GatheredContext,
215 repo: &str,
216) -> String {
217 let hooks_lines = project
218 .hooks
219 .iter()
220 .map(|h| format!("- {h}"))
221 .collect::<Vec<_>>()
222 .join("\n");
223 let commits = context.commits.join("\n");
224 let mut flat: Vec<(&str, String)> = vec![
225 ("project.name", project.name.clone()),
226 ("project.oneliner", project.oneliner.clone()),
227 ("project.audience", project.audience.clone()),
228 ("project.hooks", hooks_lines),
229 ("context.changelog", context.changelog.clone()),
230 ("context.readme", context.readme.clone()),
231 ("context.commits", commits),
232 ("context.repo", repo.to_string()),
233 ("version", context.version.clone()),
234 ];
235 if let Platform::Reddit { subreddit } = platform {
236 flat.push(("platform.subreddit", subreddit.clone()));
237 }
238 let mut out = template.to_string();
239 for (k, v) in flat {
240 out = out.replace(&format!("{{{{{k}}}}}"), &v);
241 }
242 out
243}