use std::sync::Arc;
use serde::{Deserialize, Serialize};
use crate::error::{HttpError, WechatError};
use super::{WechatApi, WechatContext};
/// Kind of temporary media handled by [`MediaApi`].
///
/// Serde (de)serializes the variants by their Rust names (`"Image"`, `"Voice"`, ...) because
/// no rename attribute is applied; use [`MediaType::as_str`] for the lowercase value that is
/// sent in the `type` query parameter.
#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub enum MediaType {
    /// Image media (`"image"`).
    Image,
    /// Voice media (`"voice"`).
    Voice,
    /// Video media (`"video"`).
    Video,
    /// Thumbnail media (`"thumb"`).
    Thumb,
}
impl MediaType {
pub fn as_str(&self) -> &'static str {
match self {
MediaType::Image => "image",
MediaType::Voice => "voice",
MediaType::Video => "video",
MediaType::Thumb => "thumb",
}
}
}
/// Decoded JSON body returned by `POST /cgi-bin/media/upload`.
///
/// The error fields are crate-private; read them through [`MediaUploadResponse::errcode`]
/// and [`MediaUploadResponse::errmsg`].
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct MediaUploadResponse {
    /// Media kind echoed back by the server (JSON key `type`).
    #[serde(rename = "type")]
    pub media_type: String,
    /// Identifier the server assigned to the uploaded media.
    pub media_id: String,
    /// Creation time reported by the server (NOTE(review): presumably Unix seconds — confirm).
    pub created_at: i64,
    /// WeChat error code; `0` when the key is absent from the body.
    #[serde(default)]
    pub(crate) errcode: i32,
    /// WeChat error message; empty when the key is absent from the body.
    #[serde(default)]
    pub(crate) errmsg: String,
}
impl MediaUploadResponse {
    /// Error code from the response body (`0` if the server omitted it).
    pub fn errcode(&self) -> i32 {
        let code = self.errcode;
        code
    }

    /// Error message from the response body (empty if the server omitted it).
    pub fn errmsg(&self) -> &str {
        self.errmsg.as_str()
    }
}
/// Client for the temporary-media endpoints (`/cgi-bin/media/*`).
pub struct MediaApi {
    /// Shared HTTP client and token manager used for every request.
    context: Arc<WechatContext>,
}
impl MediaApi {
    /// Creates a media API handle that shares the given context.
    pub fn new(context: Arc<WechatContext>) -> Self {
        Self { context }
    }

    /// Uploads a temporary media file via `POST /cgi-bin/media/upload`.
    ///
    /// `data` is sent as the multipart field `media` under `filename`, and `media_type`
    /// is passed as the `type` query parameter.
    ///
    /// # Errors
    ///
    /// Returns an error when the access token cannot be obtained, the request cannot be
    /// built or sent, the server answers with a non-success HTTP status, the body is not
    /// JSON or does not match [`MediaUploadResponse`], or the body carries a non-zero
    /// `errcode` (reported as [`WechatError::Api`]).
    pub async fn upload_temp_media(
        &self,
        media_type: MediaType,
        filename: &str,
        data: &[u8],
    ) -> Result<MediaUploadResponse, WechatError> {
        let access_token = self.context.token_manager.get_token().await?;
        let url = format!("{}/cgi-bin/media/upload", self.context.client.base_url());
        let query = [
            ("access_token", access_token.as_str()),
            ("type", media_type.as_str()),
        ];
        // `Part::bytes` needs owned data, so the caller's slice is copied once here.
        let part = reqwest::multipart::Part::bytes(data.to_vec()).file_name(filename.to_string());
        let form = reqwest::multipart::Form::new().part("media", part);
        let request = self
            .context
            .client
            .http()
            .post(&url)
            .query(&query)
            .multipart(form)
            .build()?;
        let response = self.context.client.send_request(request).await?;
        if let Err(error) = response.error_for_status_ref() {
            // The status error embeds the request URL, whose query string contains the
            // access token; strip it so the credential never reaches messages or logs.
            return Err(error.without_url().into());
        }
        // Decode to a generic value first so an error-only body (no media fields) is
        // reported as an API error rather than a deserialization failure.
        let value: serde_json::Value = response.json().await?;
        if let Some((code, message)) = parse_api_error_from_json_value(&value) {
            return Err(WechatError::Api { code, message });
        }
        let result: MediaUploadResponse = serde_json::from_value(value)
            .map_err(|error| WechatError::Http(HttpError::Decode(error.to_string())))?;
        // Also run the crate-wide check on the typed fields (its exact rules live in
        // `WechatError::check_api`).
        WechatError::check_api(result.errcode(), result.errmsg())?;
        Ok(result)
    }

    /// Downloads a temporary media file via `GET /cgi-bin/media/get` and returns the raw body.
    ///
    /// Failures may arrive as a JSON body with a success status, so the body is inspected
    /// first: if it parses as a JSON object with a non-zero `errcode`, that is returned as
    /// [`WechatError::Api`]. Any other body (binary data, or JSON without a non-zero
    /// `errcode`) is returned unchanged.
    ///
    /// # Errors
    ///
    /// Returns an error when the access token cannot be obtained, the request cannot be
    /// built or sent, the server answers with a non-success HTTP status, the body cannot be
    /// read, or the body is a WeChat error object.
    pub async fn get_temp_media(&self, media_id: &str) -> Result<Vec<u8>, WechatError> {
        let access_token = self.context.token_manager.get_token().await?;
        let url = format!("{}/cgi-bin/media/get", self.context.client.base_url());
        let query = [
            ("access_token", access_token.as_str()),
            ("media_id", media_id),
        ];
        let request = self.context.client.http().get(&url).query(&query).build()?;
        let response = self.context.client.send_request(request).await?;
        if let Err(error) = response.error_for_status_ref() {
            // Same as upload: keep the access token out of the error's URL.
            return Err(error.without_url().into());
        }
        let bytes = response.bytes().await?;
        if let Some((code, message)) = parse_api_error_from_json_bytes(&bytes) {
            return Err(WechatError::Api { code, message });
        }
        Ok(bytes.to_vec())
    }
}
/// Tries to read a raw response body as a WeChat JSON error.
///
/// Returns `None` when the body is not valid JSON (for example binary media) or when the
/// decoded JSON carries no non-zero `errcode`.
fn parse_api_error_from_json_bytes(bytes: &[u8]) -> Option<(i32, String)> {
    serde_json::from_slice::<serde_json::Value>(bytes)
        .ok()
        .and_then(|decoded| parse_api_error_from_json_value(&decoded))
}
fn parse_api_error_from_json_value(value: &serde_json::Value) -> Option<(i32, String)> {
let raw_code = value.get("errcode")?.as_i64()?;
if raw_code == 0 {
return None;
}
let code = i32::try_from(raw_code).unwrap_or_else(|_| {
if raw_code.is_negative() {
i32::MIN
} else {
i32::MAX
}
});
let message = value
.get("errmsg")
.and_then(|v| v.as_str())
.unwrap_or("unknown error")
.to_string();
Some((code, message))
}
impl WechatApi for MediaApi {
    /// Short name identifying this API group.
    fn api_name(&self) -> &'static str {
        "media"
    }

    /// Borrows the shared context out of its `Arc`.
    fn context(&self) -> &WechatContext {
        &*self.context
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{AppId, AppSecret};
    use crate::WechatClient;
    use wiremock::matchers::{method, path, query_param};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a context whose client (and token manager) talk to `base_url`.
    fn create_test_context(base_url: &str) -> Arc<WechatContext> {
        let appid = AppId::new("wx1234567890abcdef").unwrap();
        let secret = AppSecret::new("secret1234567890ab").unwrap();
        let client = Arc::new(
            WechatClient::builder()
                .appid(appid)
                .secret(secret)
                .base_url(base_url)
                .build()
                .unwrap(),
        );
        let token_manager = Arc::new(crate::token::TokenManager::new((*client).clone()));
        Arc::new(WechatContext::new(client, token_manager))
    }

    /// Stubs `GET /cgi-bin/token` so every API call is issued with `test_token`.
    async fn mount_token_mock(mock_server: &MockServer) {
        Mock::given(method("GET"))
            .and(path("/cgi-bin/token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "test_token",
                "expires_in": 7200,
                "errcode": 0,
                "errmsg": ""
            })))
            .mount(mock_server)
            .await;
    }

    #[test]
    fn test_media_type() {
        assert_eq!(MediaType::Image.as_str(), "image");
        assert_eq!(MediaType::Voice.as_str(), "voice");
        assert_eq!(MediaType::Video.as_str(), "video");
        assert_eq!(MediaType::Thumb.as_str(), "thumb");
    }

    #[test]
    fn test_parse_api_error_from_json_value() {
        // A zero or missing errcode means "no error".
        assert_eq!(
            parse_api_error_from_json_value(&serde_json::json!({"errcode": 0, "errmsg": "ok"})),
            None
        );
        assert_eq!(
            parse_api_error_from_json_value(&serde_json::json!({"media_id": "abc"})),
            None
        );
        assert_eq!(
            parse_api_error_from_json_value(
                &serde_json::json!({"errcode": 40001, "errmsg": "invalid credential"})
            ),
            Some((40001, "invalid credential".to_string()))
        );
        // A missing errmsg falls back to a placeholder.
        assert_eq!(
            parse_api_error_from_json_value(&serde_json::json!({"errcode": -1})),
            Some((-1, "unknown error".to_string()))
        );
        // Codes outside the i32 range saturate instead of wrapping.
        assert_eq!(
            parse_api_error_from_json_value(&serde_json::json!({"errcode": 9_999_999_999_i64})),
            Some((i32::MAX, "unknown error".to_string()))
        );
        assert_eq!(
            parse_api_error_from_json_value(&serde_json::json!({"errcode": -9_999_999_999_i64})),
            Some((i32::MIN, "unknown error".to_string()))
        );
    }

    #[test]
    fn test_parse_api_error_from_json_bytes() {
        // Binary payloads (a JPEG header here) and empty bodies are not errors.
        assert_eq!(parse_api_error_from_json_bytes(&[0xFF, 0xD8, 0xFF, 0xE0]), None);
        assert_eq!(parse_api_error_from_json_bytes(b""), None);
        assert_eq!(
            parse_api_error_from_json_bytes(br#"{"errcode":40007,"errmsg":"invalid media_id"}"#),
            Some((40007, "invalid media_id".to_string()))
        );
    }

    #[tokio::test]
    async fn test_upload_temp_media_success() {
        let mock_server = MockServer::start().await;
        mount_token_mock(&mock_server).await;
        Mock::given(method("POST"))
            .and(path("/cgi-bin/media/upload"))
            .and(query_param("access_token", "test_token"))
            .and(query_param("type", "image"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "type": "image",
                "media_id": "test_media_id_123",
                "created_at": 1234567890,
                "errcode": 0,
                "errmsg": ""
            })))
            .mount(&mock_server)
            .await;

        let media_api = MediaApi::new(create_test_context(&mock_server.uri()));
        let response = media_api
            .upload_temp_media(MediaType::Image, "test.jpg", b"fake_image_data")
            .await
            .expect("upload should succeed");
        assert_eq!(response.media_type, "image");
        assert_eq!(response.media_id, "test_media_id_123");
        assert_eq!(response.created_at, 1234567890);
    }

    #[tokio::test]
    async fn test_upload_temp_media_api_error() {
        let mock_server = MockServer::start().await;
        mount_token_mock(&mock_server).await;
        Mock::given(method("POST"))
            .and(path("/cgi-bin/media/upload"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "type": "",
                "media_id": "",
                "created_at": 0,
                "errcode": 40001,
                "errmsg": "invalid credential"
            })))
            .mount(&mock_server)
            .await;

        let media_api = MediaApi::new(create_test_context(&mock_server.uri()));
        let result = media_api
            .upload_temp_media(MediaType::Image, "test.jpg", b"fake_image_data")
            .await;
        match result {
            Err(WechatError::Api { code, message }) => {
                assert_eq!(code, 40001);
                assert_eq!(message, "invalid credential");
            }
            _ => panic!("Expected Api error"),
        }
    }

    #[tokio::test]
    async fn test_get_temp_media_success() {
        let mock_server = MockServer::start().await;
        mount_token_mock(&mock_server).await;
        Mock::given(method("GET"))
            .and(path("/cgi-bin/media/get"))
            .and(query_param("access_token", "test_token"))
            .and(query_param("media_id", "test_media_id"))
            .respond_with(
                ResponseTemplate::new(200).set_body_raw(b"media_binary_data", "image/jpeg"),
            )
            .mount(&mock_server)
            .await;

        let media_api = MediaApi::new(create_test_context(&mock_server.uri()));
        let data = media_api
            .get_temp_media("test_media_id")
            .await
            .expect("download should succeed");
        assert_eq!(data, b"media_binary_data");
    }

    #[tokio::test]
    async fn test_get_temp_media_error_json() {
        let mock_server = MockServer::start().await;
        mount_token_mock(&mock_server).await;
        Mock::given(method("GET"))
            .and(path("/cgi-bin/media/get"))
            .and(query_param("access_token", "test_token"))
            .and(query_param("media_id", "expired_media"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "errcode": 40007,
                "errmsg": "invalid media_id"
            })))
            .mount(&mock_server)
            .await;

        let media_api = MediaApi::new(create_test_context(&mock_server.uri()));
        match media_api.get_temp_media("expired_media").await {
            Err(WechatError::Api { code, message }) => {
                assert_eq!(code, 40007);
                assert_eq!(message, "invalid media_id");
            }
            _ => panic!("Expected WechatError::Api"),
        }
    }
}