Skip to main content

tuitbot_server/routes/
assist.rs

1//! AI assist endpoints for on-demand content generation.
2//!
//! None of these endpoints post anything: they generate content (or read
//! analytics and config) and return it. The user decides what to do with the results.
5
6use std::sync::Arc;
7
8use axum::extract::State;
9use axum::http::StatusCode;
10use axum::Json;
11use serde::{Deserialize, Serialize};
12
13use tuitbot_core::content::ContentGenerator;
14use tuitbot_core::context::winning_dna;
15use tuitbot_core::storage;
16
17use crate::account::AccountContext;
18use crate::error::ApiError;
19use crate::state::AppState;
20
21// ---------------------------------------------------------------------------
22// Helpers
23// ---------------------------------------------------------------------------
24
25async fn get_generator(
26    state: &AppState,
27    account_id: &str,
28) -> Result<Arc<ContentGenerator>, ApiError> {
29    state
30        .get_or_create_content_generator(account_id)
31        .await
32        .map_err(ApiError::BadRequest)
33}
34
35/// Resolve optional RAG context from the vault for composer assist handlers.
36///
37/// Loads the business profile's keyword set, queries winning ancestors and
38/// content seeds via `build_draft_context()`, and returns the formatted
39/// prompt block. Returns `None` (fail-open) on any error or when no
40/// relevant context exists.
41async fn resolve_composer_rag_context(state: &AppState, account_id: &str) -> Option<String> {
42    let config = match state.load_effective_config(account_id).await {
43        Ok(c) => c,
44        Err(e) => {
45            tracing::warn!("composer RAG: failed to load config: {e}");
46            return None;
47        }
48    };
49
50    let keywords = config.business.draft_context_keywords();
51    if keywords.is_empty() {
52        return None;
53    }
54
55    let draft_context = match winning_dna::build_draft_context(
56        &state.db,
57        &keywords,
58        winning_dna::MAX_ANCESTORS,
59        winning_dna::RECENCY_HALF_LIFE_DAYS,
60    )
61    .await
62    {
63        Ok(ctx) => ctx,
64        Err(e) => {
65            tracing::warn!("composer RAG: failed to build draft context: {e}");
66            return None;
67        }
68    };
69
70    if draft_context.prompt_block.is_empty() {
71        None
72    } else {
73        Some(draft_context.prompt_block)
74    }
75}
76
77// ---------------------------------------------------------------------------
78// POST /api/assist/tweet
79// ---------------------------------------------------------------------------
80
81#[derive(Deserialize)]
82pub struct AssistTweetRequest {
83    pub topic: String,
84}
85
86#[derive(Serialize)]
87pub struct AssistTweetResponse {
88    pub content: String,
89    pub topic: String,
90}
91
92pub async fn assist_tweet(
93    State(state): State<Arc<AppState>>,
94    ctx: AccountContext,
95    Json(body): Json<AssistTweetRequest>,
96) -> Result<Json<AssistTweetResponse>, ApiError> {
97    let gen = get_generator(&state, &ctx.account_id).await?;
98    let rag_context = resolve_composer_rag_context(&state, &ctx.account_id).await;
99
100    let output = gen
101        .generate_tweet_with_context(&body.topic, None, rag_context.as_deref())
102        .await
103        .map_err(|e| ApiError::Internal(e.to_string()))?;
104
105    Ok(Json(AssistTweetResponse {
106        content: output.text,
107        topic: body.topic,
108    }))
109}
110
111// ---------------------------------------------------------------------------
112// POST /api/assist/reply
113// ---------------------------------------------------------------------------
114
115#[derive(Deserialize)]
116pub struct AssistReplyRequest {
117    pub tweet_text: String,
118    pub tweet_author: String,
119    #[serde(default)]
120    pub mention_product: bool,
121}
122
123#[derive(Serialize)]
124pub struct AssistReplyResponse {
125    pub content: String,
126}
127
128pub async fn assist_reply(
129    State(state): State<Arc<AppState>>,
130    ctx: AccountContext,
131    Json(body): Json<AssistReplyRequest>,
132) -> Result<Json<AssistReplyResponse>, ApiError> {
133    let gen = get_generator(&state, &ctx.account_id).await?;
134
135    let output = gen
136        .generate_reply(&body.tweet_text, &body.tweet_author, body.mention_product)
137        .await
138        .map_err(|e| ApiError::Internal(e.to_string()))?;
139
140    Ok(Json(AssistReplyResponse {
141        content: output.text,
142    }))
143}
144
145// ---------------------------------------------------------------------------
146// POST /api/assist/thread
147// ---------------------------------------------------------------------------
148
149#[derive(Deserialize)]
150pub struct AssistThreadRequest {
151    pub topic: String,
152}
153
154#[derive(Serialize)]
155pub struct AssistThreadResponse {
156    pub tweets: Vec<String>,
157    pub topic: String,
158}
159
160pub async fn assist_thread(
161    State(state): State<Arc<AppState>>,
162    ctx: AccountContext,
163    Json(body): Json<AssistThreadRequest>,
164) -> Result<Json<AssistThreadResponse>, ApiError> {
165    let gen = get_generator(&state, &ctx.account_id).await?;
166    let rag_context = resolve_composer_rag_context(&state, &ctx.account_id).await;
167
168    let output = gen
169        .generate_thread_with_context(&body.topic, None, rag_context.as_deref())
170        .await
171        .map_err(|e| ApiError::Internal(e.to_string()))?;
172
173    Ok(Json(AssistThreadResponse {
174        tweets: output.tweets,
175        topic: body.topic,
176    }))
177}
178
179// ---------------------------------------------------------------------------
180// POST /api/assist/improve
181// ---------------------------------------------------------------------------
182
183#[derive(Deserialize)]
184pub struct AssistImproveRequest {
185    pub draft: String,
186    #[serde(default)]
187    pub context: Option<String>,
188}
189
190#[derive(Serialize)]
191pub struct AssistImproveResponse {
192    pub content: String,
193}
194
195pub async fn assist_improve(
196    State(state): State<Arc<AppState>>,
197    ctx: AccountContext,
198    Json(body): Json<AssistImproveRequest>,
199) -> Result<Json<AssistImproveResponse>, ApiError> {
200    let gen = get_generator(&state, &ctx.account_id).await?;
201    let rag_context = resolve_composer_rag_context(&state, &ctx.account_id).await;
202
203    let output = gen
204        .improve_draft_with_context(&body.draft, body.context.as_deref(), rag_context.as_deref())
205        .await
206        .map_err(|e| ApiError::Internal(e.to_string()))?;
207
208    Ok(Json(AssistImproveResponse {
209        content: output.text,
210    }))
211}
212
213// ---------------------------------------------------------------------------
214// GET /api/assist/topics
215// ---------------------------------------------------------------------------
216
217#[derive(Serialize)]
218pub struct AssistTopicsResponse {
219    pub topics: Vec<TopicRecommendation>,
220}
221
222#[derive(Serialize)]
223pub struct TopicRecommendation {
224    pub topic: String,
225    pub score: f64,
226}
227
228pub async fn assist_topics(
229    State(state): State<Arc<AppState>>,
230    ctx: AccountContext,
231) -> Result<Json<AssistTopicsResponse>, ApiError> {
232    let top = storage::analytics::get_top_topics_for(&state.db, &ctx.account_id, 10).await?;
233
234    let topics = top
235        .into_iter()
236        .map(|cs| TopicRecommendation {
237            topic: cs.topic,
238            score: cs.avg_performance,
239        })
240        .collect();
241
242    Ok(Json(AssistTopicsResponse { topics }))
243}
244
245// ---------------------------------------------------------------------------
246// GET /api/assist/optimal-times
247// ---------------------------------------------------------------------------
248
249#[derive(Serialize)]
250pub struct OptimalTimesResponse {
251    pub times: Vec<OptimalTime>,
252}
253
254#[derive(Serialize)]
255pub struct OptimalTime {
256    pub hour: u32,
257    pub avg_engagement: f64,
258    pub post_count: i64,
259}
260
261pub async fn assist_optimal_times(
262    State(state): State<Arc<AppState>>,
263    ctx: AccountContext,
264) -> Result<Json<OptimalTimesResponse>, ApiError> {
265    let rows =
266        storage::analytics::get_optimal_posting_times_for(&state.db, &ctx.account_id).await?;
267
268    let times = rows
269        .into_iter()
270        .map(|r| OptimalTime {
271            hour: r.hour as u32,
272            avg_engagement: r.avg_engagement,
273            post_count: r.post_count,
274        })
275        .collect();
276
277    Ok(Json(OptimalTimesResponse { times }))
278}
279
280// ---------------------------------------------------------------------------
281// GET /api/assist/mode
282// ---------------------------------------------------------------------------
283
284#[derive(Serialize)]
285pub struct ModeResponse {
286    pub mode: String,
287    pub approval_mode: bool,
288}
289
290pub async fn get_mode(
291    State(state): State<Arc<AppState>>,
292    ctx: AccountContext,
293) -> Result<(StatusCode, Json<ModeResponse>), ApiError> {
294    let config = crate::routes::content::read_effective_config(&state, &ctx.account_id).await?;
295
296    Ok((
297        StatusCode::OK,
298        Json(ModeResponse {
299            mode: config.mode.to_string(),
300            approval_mode: config.effective_approval_mode(),
301        }),
302    ))
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    use std::collections::HashMap;
    use std::path::PathBuf;

    use tokio::sync::{broadcast, Mutex, RwLock};

    use crate::ws::AccountWsEvent;

    /// Assemble a minimal `AppState` for exercising the RAG resolver.
    ///
    /// It uses a fresh DB from `init_test_db` and reads its config from
    /// `config_path`. `data_dir` is the config file's parent directory, or
    /// the path itself if it has no parent.
    async fn test_state(config_path: PathBuf) -> AppState {
        let db = tuitbot_core::storage::init_test_db()
            .await
            .expect("init test db");
        let (event_tx, _) = broadcast::channel::<AccountWsEvent>(16);
        let data_dir = config_path.parent().unwrap_or(&config_path).to_path_buf();

        AppState {
            db,
            config_path,
            data_dir,
            event_tx,
            api_token: "test-token".to_string(),
            passphrase_hash: RwLock::new(None),
            passphrase_hash_mtime: RwLock::new(None),
            bind_host: "127.0.0.1".to_string(),
            bind_port: 3001,
            login_attempts: Mutex::new(HashMap::new()),
            runtimes: Mutex::new(HashMap::new()),
            content_generators: Mutex::new(HashMap::new()),
            circuit_breaker: None,
            watchtower_cancel: None,
            content_sources: Default::default(),
            connector_config: Default::default(),
            deployment_mode: Default::default(),
            pending_oauth: Mutex::new(HashMap::new()),
            token_managers: Mutex::new(HashMap::new()),
            x_client_id: String::new(),
        }
    }

    /// Write `contents` to `config.toml` inside a fresh temporary directory.
    ///
    /// The returned `TempDir` must be kept alive while the path is in use,
    /// because dropping it removes the directory.
    fn write_temp_config(contents: &str) -> (tempfile::TempDir, PathBuf) {
        let dir = tempfile::tempdir().expect("create temp dir");
        let config_path = dir.path().join("config.toml");
        std::fs::write(&config_path, contents).expect("write config");
        (dir, config_path)
    }

    #[tokio::test]
    async fn resolve_rag_returns_none_when_config_missing() {
        let state = test_state(PathBuf::from("/nonexistent/config.toml")).await;

        assert!(
            resolve_composer_rag_context(&state, "test-account")
                .await
                .is_none(),
            "should return None when config is missing"
        );
    }

    #[tokio::test]
    async fn resolve_rag_returns_none_when_db_empty() {
        let (_dir, config_path) = write_temp_config(
            "[business]\nproduct_name = \"TestProduct\"\nproduct_keywords = [\"rust\", \"testing\"]\n",
        );
        let state = test_state(config_path).await;

        assert!(
            resolve_composer_rag_context(&state, "test-account")
                .await
                .is_none(),
            "should return None when DB has no ancestor data"
        );
    }

    #[tokio::test]
    async fn resolve_rag_returns_none_when_no_keywords() {
        // A business profile with no keywords hits the early `None` return.
        let (_dir, config_path) = write_temp_config("[business]\nproduct_name = \"Empty\"\n");
        let state = test_state(config_path).await;

        assert!(
            resolve_composer_rag_context(&state, "test-account")
                .await
                .is_none(),
            "should return None when keywords are empty"
        );
    }
}