Skip to main content

agent_launch/
draft.rs

1//! Draft one platform-native release announcement via Claude.
2
3use async_trait::async_trait;
4use serde::Serialize;
5use serde_json::{json, Map, Value};
6
7use crate::config::{Platform, Project};
8use crate::context::GatheredContext;
9use crate::platforms::{
10    load_platform_template, LengthCapField, OutputFormat, PlatformError, PlatformTemplate,
11};
12
/// Extra attempts allowed when a draft breaks the template's length caps
/// (so at most `MAX_RETRIES + 1` requests are sent per draft).
pub const MAX_RETRIES: usize = 2;
/// Anthropic model identifier sent with every drafting request.
pub const MODEL: &str = "claude-opus-4-7";
15
16/// Tiny adapter over the Anthropic messages API. Used so tests can supply a fake.
17///
18/// `params` is a JSON object with the same keys the SDK accepts:
19/// `model`, `max_tokens`, `temperature`, `system`, `messages`.
20/// Implementations must return a value with `content: [{type:"text", text:"..."}, ...]`.
21#[async_trait]
22pub trait AnthropicClient: Send + Sync {
23    async fn create(&self, params: Value) -> Result<Value, DraftError>;
24}
25
26#[derive(Debug, thiserror::Error)]
27pub enum DraftError {
28    #[error("platform template error: {0}")]
29    Template(#[from] PlatformError),
30    #[error("anthropic API error: {0}")]
31    Api(String),
32}
33
34#[derive(Debug, Clone, Serialize)]
35pub struct DraftResult {
36    pub platform: String,
37    pub body: String,
38    pub length: usize,
39    pub length_cap: usize,
40    pub capped: bool,
41    pub retries: usize,
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub title: Option<String>,
44    #[serde(skip_serializing_if = "Option::is_none")]
45    pub tweet_count: Option<usize>,
46}
47
/// Model output broken into the pieces a platform template cares about.
#[derive(Clone, Debug)]
struct Parsed {
    /// Main draft text.
    body: String,
    /// Title, only populated for JSON-format output.
    title: Option<String>,
    /// Tweet count, only populated for thread-format output.
    tweet_count: Option<usize>,
}
54
55pub async fn draft_one(
56    platform: &Platform,
57    project: &Project,
58    context: &GatheredContext,
59    repo: &str,
60    anthropic: &dyn AnthropicClient,
61) -> Result<DraftResult, DraftError> {
62    let tpl = load_platform_template(platform.kind())?;
63    let system_prompt = format!("{}\n\n## Anti-examples\n{}", tpl.system, tpl.anti_examples);
64    let user_prompt = render_user_prompt(&tpl.user_template, platform, project, context, repo);
65
66    let mut last: Option<Parsed> = None;
67    let mut retries = 0usize;
68
69    for attempt in 0..=MAX_RETRIES {
70        let final_user = if attempt == 0 {
71            user_prompt.clone()
72        } else {
73            let cap_clause = match tpl.length_cap_field {
74                LengthCapField::PerTweet => {
75                    format!("each tweet must be ≤ {} chars", tpl.length_cap)
76                }
77                LengthCapField::Body => {
78                    format!("the body must be ≤ {} chars", tpl.length_cap)
79                }
80            };
81            let title_clause = match tpl.title_cap {
82                Some(cap) => format!(" Title ≤ {cap} chars."),
83                None => String::new(),
84            };
85            format!(
86                "{user_prompt}\n\nThe previous attempt exceeded the length limit. Rewrite shorter; {cap_clause}.{title_clause}"
87            )
88        };
89
90        let mut params = Map::new();
91        params.insert("model".into(), json!(MODEL));
92        params.insert("max_tokens".into(), json!(4096));
93        params.insert("temperature".into(), json!(0));
94        params.insert("system".into(), json!(system_prompt));
95        params.insert(
96            "messages".into(),
97            json!([{"role": "user", "content": final_user}]),
98        );
99
100        let resp = anthropic.create(Value::Object(params)).await?;
101        let text = extract_text(&resp).trim().to_string();
102        let parsed = parse_output(&tpl.output_format, &text);
103        let ok = length_ok(&parsed, &tpl);
104        retries = attempt;
105        last = Some(parsed.clone());
106
107        if ok {
108            return Ok(make_result(platform.kind(), parsed, &tpl, true, retries));
109        }
110    }
111    let last = last.expect("loop runs at least once");
112    Ok(make_result(platform.kind(), last, &tpl, false, retries))
113}
114
115fn make_result(
116    kind: &str,
117    parsed: Parsed,
118    tpl: &PlatformTemplate,
119    capped: bool,
120    retries: usize,
121) -> DraftResult {
122    let length = parsed.body.chars().count();
123    DraftResult {
124        platform: kind.to_string(),
125        body: parsed.body,
126        length,
127        length_cap: tpl.length_cap,
128        capped,
129        retries,
130        title: parsed.title,
131        tweet_count: parsed.tweet_count,
132    }
133}
134
135fn extract_text(resp: &Value) -> String {
136    let Some(content) = resp.get("content").and_then(|c| c.as_array()) else {
137        return String::new();
138    };
139    let mut out = String::new();
140    for block in content {
141        if block.get("type").and_then(|v| v.as_str()) == Some("text") {
142            if let Some(t) = block.get("text").and_then(|v| v.as_str()) {
143                out.push_str(t);
144            }
145        }
146    }
147    out
148}
149
150fn parse_output(fmt: &OutputFormat, text: &str) -> Parsed {
151    match fmt {
152        OutputFormat::Json => match serde_json::from_str::<Value>(text) {
153            Ok(obj) => Parsed {
154                body: obj
155                    .get("body")
156                    .and_then(|v| v.as_str())
157                    .unwrap_or("")
158                    .to_string(),
159                title: Some(
160                    obj.get("title")
161                        .and_then(|v| v.as_str())
162                        .unwrap_or("")
163                        .to_string(),
164                ),
165                tweet_count: None,
166            },
167            Err(_) => Parsed {
168                body: text.to_string(),
169                title: None,
170                tweet_count: None,
171            },
172        },
173        OutputFormat::Thread => {
174            let tweets: Vec<String> = text
175                .split("---tweet---")
176                .map(|s| s.trim().to_string())
177                .filter(|s| !s.is_empty())
178                .collect();
179            let count = tweets.len();
180            Parsed {
181                body: tweets.join("\n---tweet---\n"),
182                title: None,
183                tweet_count: Some(count),
184            }
185        }
186        OutputFormat::Text => Parsed {
187            body: text.to_string(),
188            title: None,
189            tweet_count: None,
190        },
191    }
192}
193
194fn length_ok(parsed: &Parsed, tpl: &PlatformTemplate) -> bool {
195    if let (Some(cap), Some(title)) = (tpl.title_cap, parsed.title.as_ref()) {
196        if title.chars().count() > cap {
197            return false;
198        }
199    }
200    match tpl.length_cap_field {
201        LengthCapField::PerTweet => parsed
202            .body
203            .split("---tweet---")
204            .map(|s| s.trim())
205            .all(|t| t.chars().count() <= tpl.length_cap),
206        LengthCapField::Body => parsed.body.chars().count() <= tpl.length_cap,
207    }
208}
209
210fn render_user_prompt(
211    template: &str,
212    platform: &Platform,
213    project: &Project,
214    context: &GatheredContext,
215    repo: &str,
216) -> String {
217    let hooks_lines = project
218        .hooks
219        .iter()
220        .map(|h| format!("- {h}"))
221        .collect::<Vec<_>>()
222        .join("\n");
223    let commits = context.commits.join("\n");
224    let mut flat: Vec<(&str, String)> = vec![
225        ("project.name", project.name.clone()),
226        ("project.oneliner", project.oneliner.clone()),
227        ("project.audience", project.audience.clone()),
228        ("project.hooks", hooks_lines),
229        ("context.changelog", context.changelog.clone()),
230        ("context.readme", context.readme.clone()),
231        ("context.commits", commits),
232        ("context.repo", repo.to_string()),
233        ("version", context.version.clone()),
234    ];
235    if let Platform::Reddit { subreddit } = platform {
236        flat.push(("platform.subreddit", subreddit.clone()));
237    }
238    let mut out = template.to_string();
239    for (k, v) in flat {
240        out = out.replace(&format!("{{{{{k}}}}}"), &v);
241    }
242    out
243}