Skip to main content

xcom_rs/tweets/
commands.rs

1use super::{
2    ledger::IdempotencyLedger,
3    models::{Tweet, TweetFields, TweetMeta},
4};
5use anyhow::{anyhow, Context, Result};
6use serde::{Deserialize, Serialize};
7use std::str::FromStr;
8use uuid::Uuid;
9
/// Error returned by [`TweetCommand::create`] when a `client_request_id` has
/// already been recorded and the caller selected [`IfExistsPolicy::Error`].
#[derive(Debug)]
pub struct IdempotencyConflictError {
    /// The idempotency key that matched an existing ledger entry.
    pub client_request_id: String,
}

impl std::error::Error for IdempotencyConflictError {}

impl std::fmt::Display for IdempotencyConflictError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let id = &self.client_request_id;
        write!(f, "Operation with client_request_id '{}' already exists", id)
    }
}
27
28/// Policy for handling existing operations with the same client_request_id
29#[derive(Debug, Clone, Copy, PartialEq, Eq)]
30pub enum IfExistsPolicy {
31    /// Return the existing result without error
32    Return,
33    /// Return an error if operation already exists
34    Error,
35}
36
37impl FromStr for IfExistsPolicy {
38    type Err = anyhow::Error;
39
40    fn from_str(s: &str) -> Result<Self> {
41        match s {
42            "return" => Ok(Self::Return),
43            "error" => Ok(Self::Error),
44            _ => Err(anyhow!(
45                "Invalid if-exists policy: {}. Valid values: return, error",
46                s
47            )),
48        }
49    }
50}
51
52impl IfExistsPolicy {
53    pub fn as_str(&self) -> &'static str {
54        match self {
55            Self::Return => "return",
56            Self::Error => "error",
57        }
58    }
59}
60
61/// Arguments for creating a tweet
62#[derive(Debug, Clone)]
63pub struct CreateArgs {
64    pub text: String,
65    pub client_request_id: Option<String>,
66    pub if_exists: IfExistsPolicy,
67}
68
69/// Arguments for listing tweets
70#[derive(Debug, Clone)]
71pub struct ListArgs {
72    pub fields: Vec<TweetFields>,
73    pub limit: Option<usize>,
74    pub cursor: Option<String>,
75}
76
77/// Result of a create operation
78#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct CreateResult {
80    pub tweet: Tweet,
81    pub meta: TweetMeta,
82}
83
84/// Pagination metadata
85#[derive(Debug, Clone, Serialize, Deserialize)]
86#[serde(rename_all = "camelCase")]
87pub struct PaginationMeta {
88    #[serde(skip_serializing_if = "Option::is_none")]
89    pub next_cursor: Option<String>,
90    #[serde(skip_serializing_if = "Option::is_none")]
91    pub prev_cursor: Option<String>,
92}
93
94/// Result of a list operation
95#[derive(Debug, Clone, Serialize, Deserialize)]
96pub struct ListResult {
97    pub tweets: Vec<Tweet>,
98    #[serde(skip_serializing_if = "Option::is_none")]
99    pub meta: Option<ListResultMeta>,
100}
101
102/// Metadata for list results
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct ListResultMeta {
105    pub pagination: PaginationMeta,
106}
107
/// Broad category of a failure, used to decide whether a retry makes sense.
///
/// Assigned by [`ClassifiedError::from_status_code`] and [`ClassifiedError::timeout`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ErrorKind {
    /// Transient failures worth retrying: HTTP 429 and any 5xx status.
    Retryable,
    /// Failures a retry will not fix: every status other than 429 and 5xx.
    NonRetryable,
    /// Network or timeout failures that carry no HTTP status.
    Timeout,
}
118
119/// Classified error with retry information
120#[derive(Debug)]
121pub struct ClassifiedError {
122    pub kind: ErrorKind,
123    pub status_code: Option<u16>,
124    pub message: String,
125    pub is_retryable: bool,
126    pub retry_after_ms: Option<u64>,
127}
128
129impl ClassifiedError {
130    pub fn from_status_code(status_code: u16, message: String) -> Self {
131        let (kind, is_retryable) = match status_code {
132            429 => (ErrorKind::Retryable, true),
133            500..=599 => (ErrorKind::Retryable, true),
134            400..=499 => (ErrorKind::NonRetryable, false),
135            _ => (ErrorKind::NonRetryable, false),
136        };
137
138        Self {
139            kind,
140            status_code: Some(status_code),
141            message,
142            is_retryable,
143            retry_after_ms: None,
144        }
145    }
146
147    pub fn timeout(message: String) -> Self {
148        Self {
149            kind: ErrorKind::Timeout,
150            status_code: None,
151            message,
152            is_retryable: true,
153            retry_after_ms: None,
154        }
155    }
156
157    pub fn with_retry_after(mut self, retry_after_ms: u64) -> Self {
158        self.retry_after_ms = Some(retry_after_ms);
159        self
160    }
161
162    /// Convert to ErrorCode for protocol
163    pub fn to_error_code(&self) -> crate::protocol::ErrorCode {
164        use crate::protocol::ErrorCode;
165        match self.kind {
166            ErrorKind::Retryable => {
167                if let Some(429) = self.status_code {
168                    ErrorCode::RateLimitExceeded
169                } else {
170                    ErrorCode::ServiceUnavailable
171                }
172            }
173            ErrorKind::Timeout => ErrorCode::NetworkError,
174            ErrorKind::NonRetryable => ErrorCode::InternalError,
175        }
176    }
177}
178
179impl std::fmt::Display for ClassifiedError {
180    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181        write!(f, "{}", self.message)
182    }
183}
184
185impl std::error::Error for ClassifiedError {}
186
187/// Main tweets command handler
188pub struct TweetCommand {
189    ledger: IdempotencyLedger,
190}
191
192impl TweetCommand {
193    /// Create a new tweet command handler
194    pub fn new(ledger: IdempotencyLedger) -> Self {
195        Self { ledger }
196    }
197
198    /// Create a tweet with idempotency support
199    pub fn create(&self, args: CreateArgs) -> Result<CreateResult> {
200        // Check for simulated errors via environment variables (for testing)
201        if let Ok(error_type) = std::env::var("XCOM_SIMULATE_ERROR") {
202            match error_type.as_str() {
203                "rate_limit" => {
204                    let retry_after = std::env::var("XCOM_RETRY_AFTER_MS")
205                        .ok()
206                        .and_then(|s| s.parse::<u64>().ok())
207                        .unwrap_or(60000);
208                    return Err(ClassifiedError::from_status_code(
209                        429,
210                        "Rate limit exceeded".to_string(),
211                    )
212                    .with_retry_after(retry_after)
213                    .into());
214                }
215                "server_error" => {
216                    return Err(ClassifiedError::from_status_code(
217                        500,
218                        "Internal server error".to_string(),
219                    )
220                    .into());
221                }
222                "timeout" => {
223                    return Err(ClassifiedError::timeout("Request timeout".to_string()).into());
224                }
225                _ => {
226                    // Continue with normal flow for unknown error types
227                }
228            }
229        }
230
231        // Generate client_request_id if not provided
232        let client_request_id = args
233            .client_request_id
234            .unwrap_or_else(|| Uuid::new_v4().to_string());
235
236        // Compute request hash for storing (but not for lookup key)
237        let request_hash = IdempotencyLedger::compute_request_hash(&args.text);
238
239        // Check ledger for existing operation by client_request_id only
240        if let Some(entry) = self
241            .ledger
242            .lookup(&client_request_id)
243            .context("Failed to lookup operation in ledger")?
244        {
245            // Found existing operation with this client_request_id
246            match args.if_exists {
247                IfExistsPolicy::Return => {
248                    // Return cached result (even if parameters differ)
249                    let mut tweet = Tweet::new(entry.tweet_id.clone());
250                    tweet.text = Some(args.text.clone());
251
252                    let meta = TweetMeta {
253                        client_request_id: client_request_id.clone(),
254                        from_cache: Some(true),
255                    };
256
257                    return Ok(CreateResult { tweet, meta });
258                }
259                IfExistsPolicy::Error => {
260                    // Return error for duplicate client_request_id
261                    return Err(IdempotencyConflictError {
262                        client_request_id: client_request_id.clone(),
263                    }
264                    .into());
265                }
266            }
267        }
268
269        // Simulate tweet creation (in real implementation, would call X API)
270        let tweet_id = format!("tweet_{}", Uuid::new_v4());
271        let mut tweet = Tweet::new(tweet_id.clone());
272        tweet.text = Some(args.text);
273
274        // Record successful operation in ledger
275        self.ledger
276            .record(&client_request_id, &request_hash, &tweet_id, "success")
277            .context("Failed to record operation in ledger")?;
278
279        let meta = TweetMeta {
280            client_request_id,
281            from_cache: None,
282        };
283
284        Ok(CreateResult { tweet, meta })
285    }
286
287    /// List tweets with field projection and pagination
288    pub fn list(&self, args: ListArgs) -> Result<ListResult> {
289        // Check for simulated errors via environment variables (for testing)
290        if let Ok(error_type) = std::env::var("XCOM_SIMULATE_ERROR") {
291            match error_type.as_str() {
292                "rate_limit" => {
293                    let retry_after = std::env::var("XCOM_RETRY_AFTER_MS")
294                        .ok()
295                        .and_then(|s| s.parse::<u64>().ok())
296                        .unwrap_or(60000);
297                    return Err(ClassifiedError::from_status_code(
298                        429,
299                        "Rate limit exceeded".to_string(),
300                    )
301                    .with_retry_after(retry_after)
302                    .into());
303                }
304                "server_error" => {
305                    return Err(ClassifiedError::from_status_code(
306                        500,
307                        "Internal server error".to_string(),
308                    )
309                    .into());
310                }
311                "timeout" => {
312                    return Err(ClassifiedError::timeout("Request timeout".to_string()).into());
313                }
314                _ => {
315                    // Continue with normal flow for unknown error types
316                }
317            }
318        }
319
320        // Simulate fetching tweets (in real implementation, would call X API)
321        let limit = args.limit.unwrap_or(10);
322
323        // Parse cursor to determine starting offset
324        let offset = if let Some(cursor) = &args.cursor {
325            // Cursor format is "cursor_{offset}"
326            cursor
327                .strip_prefix("cursor_")
328                .and_then(|s| s.parse::<usize>().ok())
329                .unwrap_or(0)
330        } else {
331            0
332        };
333
334        let mut tweets = Vec::new();
335        for i in offset..(offset + limit) {
336            let mut tweet = Tweet::new(format!("tweet_{}", i));
337            tweet.text = Some(format!("Tweet text {}", i));
338            tweet.author_id = Some(format!("user_{}", i));
339            tweet.created_at = Some("2024-01-01T00:00:00Z".to_string());
340
341            // Apply field projection
342            let projected = tweet.project(&args.fields);
343            tweets.push(projected);
344        }
345
346        // Create pagination metadata
347        let next_cursor = if tweets.len() == limit {
348            Some(format!("cursor_{}", offset + limit))
349        } else {
350            None
351        };
352
353        let prev_cursor = if offset > 0 {
354            Some(format!("cursor_{}", offset.saturating_sub(limit)))
355        } else {
356            None
357        };
358
359        let meta = Some(ListResultMeta {
360            pagination: PaginationMeta {
361                next_cursor,
362                prev_cursor,
363            },
364        });
365
366        Ok(ListResult { tweets, meta })
367    }
368}
369
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Build a `TweetCommand` whose ledger database lives in a fresh temporary directory.
    ///
    /// Keep the returned `TempDir` bound (e.g. `_temp`, not `_`) for the whole
    /// test: it owns the directory containing `test.db`.
    fn create_test_command() -> (TweetCommand, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let ledger = IdempotencyLedger::new(Some(&db_path)).unwrap();
        let cmd = TweetCommand::new(ledger);
        (cmd, temp_dir)
    }

    #[test]
    fn test_create_generates_client_request_id() {
        let (cmd, _temp) = create_test_command();

        let args = CreateArgs {
            text: "Hello world".to_string(),
            client_request_id: None,
            if_exists: IfExistsPolicy::Return,
        };

        let result = cmd.create(args).unwrap();
        assert!(!result.meta.client_request_id.is_empty());
        assert_eq!(result.tweet.text, Some("Hello world".to_string()));
    }

    #[test]
    fn test_create_with_explicit_client_request_id() {
        let (cmd, _temp) = create_test_command();

        let args = CreateArgs {
            text: "Hello world".to_string(),
            client_request_id: Some("my-request-id".to_string()),
            if_exists: IfExistsPolicy::Return,
        };

        let result = cmd.create(args).unwrap();
        assert_eq!(result.meta.client_request_id, "my-request-id");
    }

    #[test]
    fn test_create_idempotency_return_policy() {
        let (cmd, _temp) = create_test_command();

        let args = CreateArgs {
            text: "Hello world".to_string(),
            client_request_id: Some("test-123".to_string()),
            if_exists: IfExistsPolicy::Return,
        };

        // First call
        let result1 = cmd.create(args.clone()).unwrap();
        let tweet_id1 = result1.tweet.id.clone();

        // Second call with same ID and text should return cached result
        let result2 = cmd.create(args).unwrap();
        assert_eq!(result2.tweet.id, tweet_id1);
        assert_eq!(result2.meta.from_cache, Some(true));
    }

    #[test]
    fn test_create_idempotency_error_policy() {
        let (cmd, _temp) = create_test_command();

        let args = CreateArgs {
            text: "Hello world".to_string(),
            client_request_id: Some("test-456".to_string()),
            if_exists: IfExistsPolicy::Error,
        };

        // First call succeeds
        cmd.create(args.clone()).unwrap();

        // Second call should error
        let result = cmd.create(args);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("already exists"));
    }

    #[test]
    fn test_create_idempotency_error_policy_is_typed() {
        let (cmd, _temp) = create_test_command();

        let args = CreateArgs {
            text: "Hello world".to_string(),
            client_request_id: Some("test-789".to_string()),
            if_exists: IfExistsPolicy::Error,
        };

        cmd.create(args.clone()).unwrap();

        // Callers can recover the concrete conflict error from the anyhow chain.
        let err = cmd.create(args).unwrap_err();
        let conflict = err
            .downcast_ref::<IdempotencyConflictError>()
            .expect("duplicate create should yield IdempotencyConflictError");
        assert_eq!(conflict.client_request_id, "test-789");
    }

    #[test]
    fn test_list_with_field_projection() {
        let (cmd, _temp) = create_test_command();

        let args = ListArgs {
            fields: vec![TweetFields::Id, TweetFields::Text],
            limit: Some(5),
            cursor: None,
        };

        let result = cmd.list(args).unwrap();
        assert_eq!(result.tweets.len(), 5);

        // Check that only requested fields are present
        for tweet in &result.tweets {
            assert!(!tweet.id.is_empty());
            assert!(tweet.text.is_some());
            assert!(tweet.author_id.is_none()); // Not requested
        }
    }

    #[test]
    fn test_list_pagination() {
        let (cmd, _temp) = create_test_command();

        let args = ListArgs {
            fields: TweetFields::default_fields(),
            limit: Some(10),
            cursor: None,
        };

        let result = cmd.list(args).unwrap();
        assert_eq!(result.tweets.len(), 10);
        let meta = result.meta.expect("list should return pagination metadata");
        assert_eq!(meta.pagination.next_cursor, Some("cursor_10".to_string()));
        assert!(meta.pagination.prev_cursor.is_none());
    }

    #[test]
    fn test_list_second_page_cursors() {
        let (cmd, _temp) = create_test_command();

        let args = ListArgs {
            fields: vec![TweetFields::Id, TweetFields::Text],
            limit: Some(10),
            cursor: Some("cursor_10".to_string()),
        };

        let result = cmd.list(args).unwrap();
        assert_eq!(result.tweets.len(), 10);
        // The page starts at the cursor's offset.
        assert_eq!(result.tweets[0].text, Some("Tweet text 10".to_string()));

        let meta = result.meta.expect("list should return pagination metadata");
        assert_eq!(meta.pagination.next_cursor, Some("cursor_20".to_string()));
        assert_eq!(meta.pagination.prev_cursor, Some("cursor_0".to_string()));
    }

    #[test]
    fn test_error_classification() {
        let err_429 = ClassifiedError::from_status_code(429, "Rate limit".to_string());
        assert_eq!(err_429.kind, ErrorKind::Retryable);
        assert!(err_429.is_retryable);

        let err_500 = ClassifiedError::from_status_code(500, "Server error".to_string());
        assert_eq!(err_500.kind, ErrorKind::Retryable);
        assert!(err_500.is_retryable);

        let err_400 = ClassifiedError::from_status_code(400, "Bad request".to_string());
        assert_eq!(err_400.kind, ErrorKind::NonRetryable);
        assert!(!err_400.is_retryable);

        let err_timeout = ClassifiedError::timeout("Timeout".to_string());
        assert_eq!(err_timeout.kind, ErrorKind::Timeout);
        assert!(err_timeout.is_retryable);
    }

    #[test]
    fn test_error_classification_other_status_and_retry_after() {
        // Anything other than 429 or 5xx, including 3xx, is non-retryable.
        let err_302 = ClassifiedError::from_status_code(302, "Found".to_string());
        assert_eq!(err_302.kind, ErrorKind::NonRetryable);
        assert!(!err_302.is_retryable);

        let err = ClassifiedError::from_status_code(429, "Rate limit".to_string())
            .with_retry_after(1500);
        assert_eq!(err.status_code, Some(429));
        assert_eq!(err.retry_after_ms, Some(1500));
        assert_eq!(err.to_string(), "Rate limit");
    }

    #[test]
    fn test_error_code_mapping() {
        use crate::protocol::ErrorCode;

        let rate_limited = ClassifiedError::from_status_code(429, "Rate limit".to_string());
        assert!(matches!(
            rate_limited.to_error_code(),
            ErrorCode::RateLimitExceeded
        ));

        let unavailable = ClassifiedError::from_status_code(503, "Unavailable".to_string());
        assert!(matches!(
            unavailable.to_error_code(),
            ErrorCode::ServiceUnavailable
        ));

        let timed_out = ClassifiedError::timeout("Timeout".to_string());
        assert!(matches!(timed_out.to_error_code(), ErrorCode::NetworkError));

        let not_found = ClassifiedError::from_status_code(404, "Not found".to_string());
        assert!(matches!(not_found.to_error_code(), ErrorCode::InternalError));
    }

    #[test]
    fn test_if_exists_policy_from_str() {
        assert_eq!(
            IfExistsPolicy::from_str("return").unwrap(),
            IfExistsPolicy::Return
        );
        assert_eq!(
            IfExistsPolicy::from_str("error").unwrap(),
            IfExistsPolicy::Error
        );
        assert!(IfExistsPolicy::from_str("invalid").is_err());
    }
}