Skip to main content

heartbit_core/tool/builtins/
twitter_post.rs

1#![allow(missing_docs)]
2use std::fmt::Write as _;
3use std::future::Future;
4use std::pin::Pin;
5use std::time::SystemTime;
6
7use base64::Engine;
8use hmac::{Hmac, Mac};
9use serde_json::json;
10use sha1::Sha1;
11
12use crate::error::Error;
13use crate::llm::types::ToolDefinition;
14use crate::tool::{Tool, ToolOutput};
15
16const X_API_URL: &str = "https://api.twitter.com/2/tweets";
17const MAX_TWEET_LENGTH: usize = 280;
18
19type HmacSha1 = Hmac<Sha1>;
20
/// OAuth 1.0a credentials for one tenant's X/Twitter account.
///
/// The consumer pair identifies the application and the access-token pair
/// identifies the account being acted for. Both secrets form the HMAC signing
/// key. The `Debug` impl masks every field so the values never reach logs.
#[derive(Clone)]
pub struct TwitterCredentials {
    /// Application (consumer) key, sent as `oauth_consumer_key`.
    pub consumer_key: String,
    /// Application (consumer) secret; first half of the signing key.
    pub consumer_secret: String,
    /// Account access token, sent as `oauth_token`.
    pub access_token: String,
    /// Account access-token secret; second half of the signing key.
    pub access_token_secret: String,
}

impl std::fmt::Debug for TwitterCredentials {
    /// Renders the struct with every field shown as `"[REDACTED]"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "[REDACTED]";
        const FIELD_NAMES: [&str; 4] = [
            "consumer_key",
            "consumer_secret",
            "access_token",
            "access_token_secret",
        ];

        let mut out = f.debug_struct("TwitterCredentials");
        for name in FIELD_NAMES {
            out.field(name, &REDACTED);
        }
        out.finish()
    }
}
40
41/// Builtin tool for posting tweets to X/Twitter via API v2.
42///
43/// Uses OAuth 1.0a signing with per-tenant credentials injected at runtime.
44/// Only instantiated when `TwitterCredentials` are provided (multi-tenant).
45pub struct TwitterPostTool {
46    credentials: TwitterCredentials,
47    client: reqwest::Client,
48}
49
50impl TwitterPostTool {
51    /// Create a `TwitterPostTool` with the given credentials.
52    ///
53    /// Panics if the HTTP client cannot be built. Use [`TwitterPostTool::try_new`]
54    /// if you need to handle the error.
55    pub fn new(credentials: TwitterCredentials) -> Self {
56        Self::try_new(credentials).expect("failed to build reqwest client")
57    }
58
59    /// Create a `TwitterPostTool` with the given credentials, returning `Err` on failure.
60    ///
61    /// Returns `Err` if the underlying HTTP client cannot be constructed
62    /// (e.g., TLS initialisation failure).
63    pub fn try_new(credentials: TwitterCredentials) -> Result<Self, crate::error::Error> {
64        let client = crate::http::vendor_client_builder()
65            .timeout(std::time::Duration::from_secs(30))
66            .build()
67            .map_err(|e| {
68                crate::error::Error::Agent(format!("failed to build reqwest client: {e}"))
69            })?;
70        Ok(Self {
71            credentials,
72            client,
73        })
74    }
75}
76
/// Percent-encode `s` as required by RFC 5849 §3.6 (OAuth 1.0a).
///
/// Bytes in the unreserved set `A-Z a-z 0-9 - . _ ~` are copied verbatim.
/// Every other byte of the UTF-8 encoding becomes `%XX` with upper-case hex digits.
fn percent_encode(s: &str) -> String {
    s.bytes()
        .fold(String::with_capacity(s.len() * 2), |mut out, b| {
            let unreserved = b.is_ascii_alphanumeric() || matches!(b, b'-' | b'.' | b'_' | b'~');
            if unreserved {
                out.push(char::from(b));
            } else {
                // Formatting into a String cannot fail, so the fmt::Result is discarded.
                let _ = write!(out, "%{b:02X}");
            }
            out
        })
}
93
94/// Build the OAuth 1.0a Authorization header for a POST request.
95fn build_oauth_header(
96    url: &str,
97    consumer_key: &str,
98    consumer_secret: &str,
99    access_token: &str,
100    access_token_secret: &str,
101    nonce: &str,
102    timestamp: u64,
103) -> Result<String, Error> {
104    let oauth_params = [
105        ("oauth_consumer_key", consumer_key),
106        ("oauth_nonce", nonce),
107        ("oauth_signature_method", "HMAC-SHA1"),
108        ("oauth_timestamp", &timestamp.to_string()),
109        ("oauth_token", access_token),
110        ("oauth_version", "1.0"),
111    ];
112
113    // Build parameter string (sorted by key)
114    let param_string: String = oauth_params
115        .iter()
116        .map(|(k, v)| format!("{}={}", percent_encode(k), percent_encode(v)))
117        .collect::<Vec<_>>()
118        .join("&");
119
120    // Build signature base string: METHOD&url&params
121    let base_string = format!(
122        "POST&{}&{}",
123        percent_encode(url),
124        percent_encode(&param_string),
125    );
126
127    // Sign with HMAC-SHA1
128    let signing_key = format!(
129        "{}&{}",
130        percent_encode(consumer_secret),
131        percent_encode(access_token_secret),
132    );
133
134    let mut mac = HmacSha1::new_from_slice(signing_key.as_bytes())
135        .map_err(|e| Error::Agent(format!("HMAC key error: {e}")))?;
136    mac.update(base_string.as_bytes());
137    let signature = base64::engine::general_purpose::STANDARD.encode(mac.finalize().into_bytes());
138
139    // Build Authorization header
140    Ok(format!(
141        "OAuth oauth_consumer_key=\"{}\", \
142         oauth_nonce=\"{}\", \
143         oauth_signature=\"{}\", \
144         oauth_signature_method=\"HMAC-SHA1\", \
145         oauth_timestamp=\"{}\", \
146         oauth_token=\"{}\", \
147         oauth_version=\"1.0\"",
148        percent_encode(consumer_key),
149        percent_encode(nonce),
150        percent_encode(&signature),
151        timestamp,
152        percent_encode(access_token),
153    ))
154}
155
156impl Tool for TwitterPostTool {
157    fn definition(&self) -> ToolDefinition {
158        ToolDefinition {
159            name: "twitter_post".into(),
160            description: "Post a tweet to X/Twitter. Maximum 280 characters.".into(),
161            input_schema: json!({
162                "type": "object",
163                "properties": {
164                    "text": {
165                        "type": "string",
166                        "description": "The tweet text to post (max 280 characters)"
167                    }
168                },
169                "required": ["text"]
170            }),
171        }
172    }
173
174    fn execute(
175        &self,
176        _ctx: &crate::ExecutionContext,
177        input: serde_json::Value,
178    ) -> Pin<Box<dyn Future<Output = Result<ToolOutput, Error>> + Send + '_>> {
179        Box::pin(async move {
180            let text = input
181                .get("text")
182                .and_then(|v| v.as_str())
183                .ok_or_else(|| Error::Agent("text is required".into()))?;
184
185            if text.is_empty() {
186                return Ok(ToolOutput::error("text must not be empty"));
187            }
188
189            let char_count = text.chars().count();
190            if char_count > MAX_TWEET_LENGTH {
191                return Ok(ToolOutput::error(format!(
192                    "Tweet exceeds {MAX_TWEET_LENGTH} characters (got {char_count}). \
193                     Please shorten your tweet."
194                )));
195            }
196
197            // Generate OAuth nonce and timestamp
198            let timestamp = SystemTime::now()
199                .duration_since(SystemTime::UNIX_EPOCH)
200                .map_err(|e| Error::Agent(format!("system time error: {e}")))?
201                .as_secs();
202
203            // UUID v4 provides cryptographically random nonce (required by RFC 5849)
204            let nonce = uuid::Uuid::new_v4().to_string().replace('-', "");
205
206            let auth_header = build_oauth_header(
207                X_API_URL,
208                &self.credentials.consumer_key,
209                &self.credentials.consumer_secret,
210                &self.credentials.access_token,
211                &self.credentials.access_token_secret,
212                &nonce,
213                timestamp,
214            )?;
215
216            let body = json!({ "text": text });
217
218            let response = self
219                .client
220                .post(X_API_URL)
221                .header("Authorization", &auth_header)
222                .header("Content-Type", "application/json")
223                .json(&body)
224                .send()
225                .await
226                .map_err(|e| Error::Agent(format!("X API request failed: {e}")))?;
227
228            let status = response.status();
229            // SECURITY (F-NET-1): cap response body. Tweet responses are tiny
230            // (well under 1 MiB) — 256 KiB is generous and bounds memory.
231            let (body_bytes, _truncated) = crate::http::read_body_capped(response, 256 * 1024)
232                .await
233                .map_err(|e| Error::Agent(format!("Failed to read X API response: {e}")))?;
234            let response_body: serde_json::Value = serde_json::from_slice(&body_bytes)
235                .map_err(|e| Error::Agent(format!("Failed to parse X API response: {e}")))?;
236
237            if !status.is_success() {
238                let detail = response_body
239                    .get("detail")
240                    .and_then(|v| v.as_str())
241                    .or_else(|| response_body.get("title").and_then(|v| v.as_str()))
242                    .unwrap_or("Unknown error");
243                return Ok(ToolOutput::error(format!(
244                    "X API error (HTTP {}): {detail}",
245                    status.as_u16()
246                )));
247            }
248
249            // Extract tweet ID from response
250            let tweet_id = response_body
251                .get("data")
252                .and_then(|d| d.get("id"))
253                .and_then(|v| v.as_str())
254                .unwrap_or("unknown");
255
256            Ok(ToolOutput::success(format!(
257                "Tweet posted successfully!\n\
258                 Tweet ID: {tweet_id}\n\
259                 URL: https://x.com/i/status/{tweet_id}\n\
260                 Text: {text}"
261            )))
262        })
263    }
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// Dummy credentials; they are never valid against the real API.
    fn test_credentials() -> TwitterCredentials {
        TwitterCredentials {
            consumer_key: "test_consumer_key".into(),
            consumer_secret: "test_consumer_secret".into(),
            access_token: "test_access_token".into(),
            access_token_secret: "test_access_token_secret".into(),
        }
    }

    #[test]
    fn definition_has_correct_name() {
        let tool = TwitterPostTool::new(test_credentials());
        assert_eq!(tool.definition().name, "twitter_post");
    }

    #[test]
    fn definition_requires_text() {
        let tool = TwitterPostTool::new(test_credentials());
        let schema = &tool.definition().input_schema;
        let required = schema["required"].as_array().unwrap();
        assert_eq!(required.len(), 1);
        assert_eq!(required[0], "text");
    }

    #[test]
    fn percent_encode_unreserved() {
        assert_eq!(percent_encode("abc123"), "abc123");
        assert_eq!(
            percent_encode("hello-world_test.v2~"),
            "hello-world_test.v2~"
        );
    }

    #[test]
    fn percent_encode_reserved() {
        assert_eq!(percent_encode("hello world"), "hello%20world");
        assert_eq!(percent_encode("a&b=c"), "a%26b%3Dc");
        assert_eq!(percent_encode("100%"), "100%25");
    }

    #[test]
    fn percent_encode_special_chars() {
        assert_eq!(percent_encode("/"), "%2F");
        assert_eq!(percent_encode(":"), "%3A");
        assert_eq!(percent_encode("@"), "%40");
    }

    #[test]
    fn percent_encode_multibyte_utf8() {
        // Non-ASCII characters are encoded byte-by-byte from their UTF-8 form.
        assert_eq!(percent_encode("é"), "%C3%A9");
        assert_eq!(percent_encode("☃"), "%E2%98%83");
    }

    #[test]
    fn build_oauth_header_produces_valid_format() {
        let header = build_oauth_header(
            "https://api.twitter.com/2/tweets",
            "consumer_key",
            "consumer_secret",
            "access_token",
            "access_token_secret",
            "testnonce123",
            1234567890,
        )
        .unwrap();

        assert!(header.starts_with("OAuth "));
        assert!(header.contains("oauth_consumer_key=\"consumer_key\""));
        assert!(header.contains("oauth_nonce=\"testnonce123\""));
        assert!(header.contains("oauth_signature_method=\"HMAC-SHA1\""));
        assert!(header.contains("oauth_timestamp=\"1234567890\""));
        assert!(header.contains("oauth_token=\"access_token\""));
        assert!(header.contains("oauth_version=\"1.0\""));
        assert!(header.contains("oauth_signature=\""));
    }

    #[test]
    fn build_oauth_header_fields_are_sorted_and_encoded() {
        let header =
            build_oauth_header(X_API_URL, "a b", "cs", "at", "ats", "nonce", 1000).unwrap();
        let fields = header
            .strip_prefix("OAuth ")
            .expect("header must use the OAuth scheme");
        let names: Vec<&str> = fields
            .split(", ")
            .map(|field| field.split('=').next().unwrap())
            .collect();
        assert_eq!(
            names,
            [
                "oauth_consumer_key",
                "oauth_nonce",
                "oauth_signature",
                "oauth_signature_method",
                "oauth_timestamp",
                "oauth_token",
                "oauth_version",
            ]
        );
        // Header values are percent-encoded, not sent raw.
        assert!(header.contains("oauth_consumer_key=\"a%20b\""));
    }

    #[test]
    fn build_oauth_header_signature_is_deterministic() {
        let h1 = build_oauth_header(X_API_URL, "ck", "cs", "at", "ats", "nonce", 1000).unwrap();
        let h2 = build_oauth_header(X_API_URL, "ck", "cs", "at", "ats", "nonce", 1000).unwrap();
        assert_eq!(h1, h2);
    }

    #[test]
    fn build_oauth_header_different_nonce_produces_different_signature() {
        let h1 = build_oauth_header(X_API_URL, "ck", "cs", "at", "ats", "nonce1", 1000).unwrap();
        let h2 = build_oauth_header(X_API_URL, "ck", "cs", "at", "ats", "nonce2", 1000).unwrap();
        assert_ne!(h1, h2);
    }

    #[tokio::test]
    async fn rejects_empty_text() {
        let tool = TwitterPostTool::new(test_credentials());
        let result = tool
            .execute(&crate::ExecutionContext::default(), json!({"text": ""}))
            .await
            .unwrap();
        assert!(result.is_error);
        assert!(result.content.contains("must not be empty"));
    }

    #[tokio::test]
    async fn rejects_text_too_long() {
        let tool = TwitterPostTool::new(test_credentials());
        let long = "a".repeat(281);
        let result = tool
            .execute(&crate::ExecutionContext::default(), json!({"text": long}))
            .await
            .unwrap();
        assert!(result.is_error);
        assert!(result.content.contains("exceeds 280 characters"));
    }

    #[tokio::test]
    async fn rejects_multibyte_text_too_long() {
        // The limit is counted in characters, not UTF-8 bytes.
        let tool = TwitterPostTool::new(test_credentials());
        let long = "é".repeat(281);
        let result = tool
            .execute(&crate::ExecutionContext::default(), json!({"text": long}))
            .await
            .unwrap();
        assert!(result.is_error);
        assert!(result.content.contains("got 281"), "got: {}", result.content);
    }

    #[tokio::test]
    async fn rejects_missing_text() {
        let tool = TwitterPostTool::new(test_credentials());
        let result = tool
            .execute(&crate::ExecutionContext::default(), json!({}))
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("text is required"), "got: {err}");
    }

    #[test]
    fn credentials_debug_redacts_secrets() {
        let creds = test_credentials();
        let debug = format!("{creds:?}");
        assert!(debug.contains("[REDACTED]"));
        assert!(!debug.contains("test_consumer_key"));
        assert!(!debug.contains("test_consumer_secret"));
        // Also covers "test_access_token_secret", which contains this string.
        assert!(!debug.contains("test_access_token"));
    }

    // Ignored by default: passing validation means this test sends a real request
    // to api.twitter.com with fake credentials. That is slow, depends on the
    // network, and can wait up to the client's 30 s timeout.
    #[tokio::test]
    #[ignore = "sends a real HTTP request to api.twitter.com; run with --ignored"]
    async fn accepts_280_chars() {
        // 280 chars should pass validation (will fail at HTTP level, but that's expected)
        let tool = TwitterPostTool::new(test_credentials());
        let text = "a".repeat(280);
        let result = tool
            .execute(&crate::ExecutionContext::default(), json!({"text": text}))
            .await;
        // Should not be a validation error — will fail at HTTP level
        match result {
            Ok(output) => {
                // Network error is fine, but should NOT be a validation error
                if output.is_error {
                    assert!(
                        !output.content.contains("exceeds"),
                        "280 chars should not be rejected: {}",
                        output.content
                    );
                }
            }
            Err(_) => {
                // Network error is expected with fake credentials
            }
        }
    }
}