1use std::time::{Duration, Instant};
27
28use base64;
29use serde::de::Error;
30use serde::{Deserialize, Deserializer};
31
32use crate::common::*;
33use crate::{auth, error, links};
34
35use mime;
36
37pub mod media_types {
47 use mime::{self, Mime};
48
49 pub fn image_png() -> Mime {
51 mime::IMAGE_PNG
52 }
53
54 pub fn image_jpg() -> Mime {
56 mime::IMAGE_JPEG
57 }
58
59 pub fn image_webp() -> Mime {
61 "image/webp".parse().unwrap()
62 }
63
64 pub fn image_gif() -> Mime {
66 mime::IMAGE_GIF
67 }
68
69 pub fn video_mp4() -> Mime {
71 "video/mp4".parse().unwrap()
72 }
73}
74
/// Processing state of an uploaded media item, decoded from a `processing_info` object.
#[derive(Debug, Clone, PartialEq)]
pub enum ProgressInfo {
    /// The server reported state `pending`. Holds the `check_after_secs` value
    /// (presumably seconds to wait before polling again — confirm against API docs).
    Pending(u64),
    /// The server reported state `in_progress`. Holds `check_after_secs`, as for `Pending`.
    InProgress(u64),
    /// The server reported state `failed`; holds the accompanying `error` object.
    Failed(error::MediaError),
    /// The server reported state `succeeded`.
    Success,
}
87
88#[derive(Debug, Deserialize)]
89enum RawProgressInfoTag {
90 #[serde(rename = "pending")]
91 Pending,
92 #[serde(rename = "in_progress")]
93 InProgress,
94 #[serde(rename = "failed")]
95 Failed,
96 #[serde(rename = "succeeded")]
97 Success,
98}
99
/// Raw `processing_info` object; validated and converted by `ProgressInfo`'s `Deserialize` impl.
#[derive(Debug, Deserialize)]
struct RawProgressInfo {
    /// Current processing state.
    state: RawProgressInfoTag,
    /// Reported completion percentage; parsed but not carried into `ProgressInfo`.
    progress_percent: Option<f64>,
    /// Required when `state` is `pending` or `in_progress`.
    check_after_secs: Option<u64>,
    /// Required when `state` is `failed`.
    error: Option<error::MediaError>,
}
107
108impl<'de> Deserialize<'de> for ProgressInfo {
109 fn deserialize<D>(deser: D) -> Result<ProgressInfo, D::Error>
110 where
111 D: Deserializer<'de>,
112 {
113 use self::RawProgressInfoTag::*;
114 let raw = RawProgressInfo::deserialize(deser)?;
115 let check_after = raw
116 .check_after_secs
117 .ok_or_else(|| D::Error::custom("Missing field: check_after_secs"));
118 Ok(match raw.state {
119 Pending => ProgressInfo::Pending(check_after?),
120 InProgress => ProgressInfo::InProgress(check_after?),
121 Success => ProgressInfo::Success,
122 Failed => {
123 let err = raw
124 .error
125 .ok_or_else(|| D::Error::custom("Missing field: error"))?;
126 ProgressInfo::Failed(err)
127 }
128 })
129 }
130}
131
/// Raw body of an upload (`INIT`/`FINALIZE`) or `STATUS` response, before conversion into
/// a `MediaHandle`.
#[derive(Debug, Deserialize)]
struct RawMedia {
    /// Media ID, read from the string form `media_id_string`.
    #[serde(rename = "media_id_string")]
    id: String,
    /// Lifetime in seconds, read from `expires_after_secs`; defaults to 0 when absent.
    #[serde(default)]
    #[serde(rename = "expires_after_secs")]
    expires_after: u64,
    /// Processing state from `processing_info`; `None` when that field is absent.
    #[serde(rename = "processing_info")]
    progress: Option<ProgressInfo>,
}
146
/// Opaque identifier of an uploaded media item (the server's `media_id_string`).
#[derive(Debug, Clone)]
pub struct MediaId(pub(crate) String);

impl From<String> for MediaId {
    fn from(id: String) -> Self {
        MediaId(id)
    }
}
150
/// Handle to an uploaded media item, returned by the upload and status functions.
#[derive(Debug, Clone)]
pub struct MediaHandle {
    /// ID used to refer to this media in later requests (e.g. `get_status`, `set_metadata`).
    pub id: MediaId,
    /// Instant after which the media is treated as expired; derived from the
    /// server-reported `expires_after_secs` when the handle was built. See `is_valid`.
    pub expires_at: Instant,
    /// Processing state reported by the server, if any; refresh it with `get_status`.
    pub progress: Option<ProgressInfo>,
}
161
162impl From<RawMedia> for MediaHandle {
163 fn from(raw: RawMedia) -> Self {
164 Self {
165 id: raw.id.into(),
166 expires_at: Instant::now() + Duration::from_secs(raw.expires_after),
169 progress: raw.progress,
170 }
171 }
172}
173
174impl MediaHandle {
175 pub fn is_valid(&self) -> bool {
179 Instant::now() < self.expires_at
180 }
181}
182
/// Category sent as the `media_category` parameter of an upload `INIT` request.
///
/// Its `Display` output is the tweet category string (`tweet_image`, `tweet_gif`,
/// `tweet_video`). The direct-message variants come from `dm_category`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum MediaCategory {
    /// Still image; also the fallback for unrecognized media types.
    Image,
    /// Animated GIF.
    Gif,
    /// Video.
    Video,
}

impl std::fmt::Display for MediaCategory {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            MediaCategory::Image => "tweet_image",
            MediaCategory::Gif => "tweet_gif",
            MediaCategory::Video => "tweet_video",
        };
        f.write_str(name)
    }
}
197
198impl From<&mime::Mime> for MediaCategory {
199 fn from(mime: &mime::Mime) -> Self {
200 if mime == &media_types::image_gif() {
201 MediaCategory::Gif
202 } else if mime == &media_types::video_mp4() {
203 MediaCategory::Video
204 } else {
205 MediaCategory::Image
207 }
208 }
209}
210
211impl MediaCategory {
212 fn dm_category(&self) -> &'static str {
213 match self {
214 MediaCategory::Image => "dm_image",
215 MediaCategory::Gif => "dm_gif",
216 MediaCategory::Video => "dm_video",
217 }
218 }
219}
220
221pub async fn upload_media(
228 data: &[u8],
229 media_type: &mime::Mime,
230 token: &auth::Token,
231) -> error::Result<MediaHandle> {
232 let media_category = MediaCategory::from(media_type);
233 let params = ParamList::new()
234 .add_param("command", "INIT")
235 .add_param("total_bytes", data.len().to_string())
236 .add_param("media_type", media_type.to_string())
237 .add_param("media_category", media_category.to_string());
238 let req = post(links::media::UPLOAD, token, Some(¶ms));
239
240 let media = request_with_json_response::<RawMedia>(req).await?.response;
241
242 finish_upload(media, data, token).await
243}
244
245pub async fn upload_media_for_dm(
263 data: &[u8],
264 media_type: &mime::Mime,
265 shared: bool,
266 token: &auth::Token,
267) -> error::Result<MediaHandle> {
268 let media_category = MediaCategory::from(media_type);
269 let params = ParamList::new()
270 .add_param("command", "INIT")
271 .add_param("total_bytes", data.len().to_string())
272 .add_param("media_type", media_type.to_string())
273 .add_param("media_category", media_category.dm_category())
274 .add_param("shared", shared.to_string());
275 let req = post(links::media::UPLOAD, token, Some(¶ms));
276
277 let media = request_with_json_response::<RawMedia>(req).await?.response;
278
279 finish_upload(media, data, token).await
280}
281
282async fn finish_upload(
283 media: RawMedia,
284 data: &[u8],
285 token: &auth::Token,
286) -> error::Result<MediaHandle> {
287 for (ix, chunk) in data.chunks(1024 * 1024).enumerate() {
289 let params = ParamList::new()
290 .add_param("command", "APPEND")
291 .add_param("media_id", media.id.clone())
292 .add_param("media_data", base64::encode(chunk))
293 .add_param("segment_index", ix.to_string());
294 let req = post(links::media::UPLOAD, token, Some(¶ms));
295 raw_request(req).await?;
297 }
298
299 let params = ParamList::new()
300 .add_param("command", "FINALIZE")
301 .add_param("media_id", media.id.clone());
302 let req = post(links::media::UPLOAD, token, Some(¶ms));
303 Ok(request_with_json_response::<RawMedia>(req)
304 .await?
305 .response
306 .into())
307}
308
309pub async fn get_status(media_id: MediaId, token: &auth::Token) -> error::Result<MediaHandle> {
311 let params = ParamList::new()
312 .add_param("command", "STATUS")
313 .add_param("media_id", media_id.0);
314 let req = get(links::media::UPLOAD, token, Some(¶ms));
315 Ok(request_with_json_response::<RawMedia>(req)
316 .await?
317 .response
318 .into())
319}
320
321pub async fn set_metadata(
324 media_id: &MediaId,
325 alt_text: &str,
326 token: &auth::Token,
327) -> error::Result<()> {
328 let payload = serde_json::json!({
329 "media_id": media_id.0,
330 "alt_text": {
331 "text": alt_text
332 }
333 });
334 let req = post_json(links::media::METADATA, token, payload);
335 raw_request(req).await?;
336 Ok(())
337}
338
#[cfg(test)]
mod tests {
    use super::{ProgressInfo, RawMedia};
    use crate::common::tests::load_file;

    /// Loads the sample payload at `path` and parses it as `RawMedia`, panicking on failure.
    fn load_media(path: &str) -> RawMedia {
        let content = load_file(path);
        ::serde_json::from_str::<RawMedia>(&content).unwrap()
    }

    #[test]
    fn parse_media() {
        let media = load_media("sample_payloads/media.json");

        assert_eq!(media.id, "710511363345354753");
        assert_eq!(media.expires_after, 86400);
    }

    #[test]
    fn parse_media_pending() {
        let media = load_media("sample_payloads/media_pending.json");

        assert_eq!(media.id, "13");
        assert_eq!(media.expires_after, 86400);

        match media.progress {
            Some(ProgressInfo::Pending(5)) => (),
            other => panic!("Unexpected value of progress={:?}", other),
        }
    }

    #[test]
    fn parse_media_in_progress() {
        let media = load_media("sample_payloads/media_in_progress.json");

        assert_eq!(media.id, "13");
        assert_eq!(media.expires_after, 3595);

        match media.progress {
            Some(ProgressInfo::InProgress(10)) => (),
            other => panic!("Unexpected value of progress={:?}", other),
        }
    }

    #[test]
    fn parse_media_fail() {
        let media = load_media("sample_payloads/media_fail.json");

        assert_eq!(media.id, "710511363345354753");
        // `expires_after_secs` is absent in this payload, so `#[serde(default)]` yields 0.
        assert_eq!(media.expires_after, 0);

        match media.progress {
            Some(ProgressInfo::Failed(error)) => assert_eq!(
                error,
                crate::error::MediaError {
                    code: 1,
                    name: "InvalidMedia".to_string(),
                    message: "Unsupported video format".to_string(),
                }
            ),
            other => panic!("Unexpected value of progress={:?}", other),
        }
    }
}
405}