use crate::auth::TokenManager;
use crate::error::{Result, WeChatError};
use crate::http::{DraftResponse, MaterialUploadResponse, WeChatHttpClient, WeChatResponse};
use crate::markdown::ImageRef;
use blake3;
use futures::future::try_join_all;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::fs;
use tokio::sync::{RwLock, Semaphore};
use tracing::{debug, info, warn};
/// Number of semaphore permits given to an `ImageUploader`, i.e. how many images
/// it loads and uploads at the same time.
const MAX_CONCURRENT_UPLOADS: usize = 5;
/// Lifetime of an entry in the in-memory material cache (five minutes).
const MATERIAL_CACHE_TTL: Duration = Duration::from_secs(5 * 60);
/// Number of cached materials at which the material cache starts evicting entries.
const MAX_CACHE_SIZE: usize = 1_000;
/// An entry in `ImageUploader`'s in-memory material cache.
///
/// The insertion instant is kept so the entry can be treated as stale once it is
/// older than `MATERIAL_CACHE_TTL`.
#[derive(Debug, Clone)]
struct CachedMaterial {
    /// The cached material record (media_id, name, update_time, url).
    material: MaterialItem,
    /// When this entry was created; its age is compared against `MATERIAL_CACHE_TTL`.
    cached_at: Instant,
}
impl CachedMaterial {
    /// Wraps `material` in a cache entry stamped with the current instant.
    fn new(material: MaterialItem) -> Self {
        let cached_at = Instant::now();
        Self { material, cached_at }
    }

    /// Returns `true` once this entry has lived longer than `MATERIAL_CACHE_TTL`.
    fn is_expired(&self) -> bool {
        let age = self.cached_at.elapsed();
        age > MATERIAL_CACHE_TTL
    }
}
/// Maximum size, in bytes, of a local image file accepted for upload (10 MiB).
const MAX_IMAGE_SIZE: u64 = 10 << 20;
/// Byte cap handed to the HTTP client when downloading a remote image (50 MiB).
const MAX_DOWNLOAD_SIZE: u64 = 50 << 20;
/// Outcome of uploading a single image.
#[derive(Debug, Clone)]
pub struct UploadResult {
    /// The image reference the upload was made for (as found in the source document).
    pub image_ref: ImageRef,
    /// WeChat media_id of the permanent material.
    pub media_id: String,
    /// URL WeChat reports for the uploaded material; `DraftManager::create_url_mapping`
    /// maps `image_ref.original_url` to this value.
    pub url: String,
}
/// One article of a WeChat draft.
///
/// It is serialized as-is into draft add/update requests and deserialized from draft
/// listings. The `u8` fields are 0/1 flags, as written by the `with_*` builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Article {
    /// Article title; `DraftManager` also uses it to find an existing draft to update.
    pub title: String,
    pub author: String,
    /// Article body (HTML in this module's tests — presumably always HTML; confirm).
    pub content: String,
    /// Optional source URL, set via `with_source_url`.
    pub content_source_url: Option<String>,
    /// Summary text; empty unless set via `with_digest`.
    pub digest: String,
    /// 1 = show the cover picture, 0 = hide it (defaults to 1).
    pub show_cover_pic: u8,
    /// media_id of the cover image material, set via `with_cover_image`.
    pub thumb_media_id: Option<String>,
    /// 1 = comments enabled, 0 = disabled (defaults to 0).
    pub need_open_comment: u8,
    /// 1 = only followers ("fans") may comment, 0 = anyone (defaults to 0).
    pub only_fans_can_comment: u8,
}
impl Article {
    /// Builds an article from its title, author and body.
    ///
    /// Defaults: no source URL, empty digest, no cover thumbnail, cover shown
    /// (`show_cover_pic = 1`), comments closed and unrestricted (both flags 0).
    pub fn new(title: String, author: String, content: String) -> Self {
        Self {
            title,
            author,
            content,
            digest: String::new(),
            content_source_url: None,
            thumb_media_id: None,
            show_cover_pic: 1,
            need_open_comment: 0,
            only_fans_can_comment: 0,
        }
    }

    /// Replaces the summary text.
    pub fn with_digest(self, digest: String) -> Self {
        Self { digest, ..self }
    }

    /// Sets the media_id of the cover image material.
    pub fn with_cover_image(self, thumb_media_id: String) -> Self {
        Self {
            thumb_media_id: Some(thumb_media_id),
            ..self
        }
    }

    /// Chooses whether the cover picture is shown (stored as 1/0).
    pub fn with_show_cover(self, show: bool) -> Self {
        Self {
            show_cover_pic: u8::from(show),
            ..self
        }
    }

    /// Configures commenting: `enable_comments` opens comments and `fans_only`
    /// restricts them to followers (each stored as 1/0).
    pub fn with_comments(self, enable_comments: bool, fans_only: bool) -> Self {
        Self {
            need_open_comment: u8::from(enable_comments),
            only_fans_can_comment: u8::from(fans_only),
            ..self
        }
    }

    /// Sets the source URL.
    pub fn with_source_url(self, url: String) -> Self {
        Self {
            content_source_url: Some(url),
            ..self
        }
    }
}
/// Request body sent to `/cgi-bin/draft/add` by `DraftManager::create_draft`.
#[derive(Debug, Serialize)]
struct DraftRequest {
    articles: Vec<Article>,
}
/// A draft entry as deserialized from the draft endpoints.
#[derive(Debug, Deserialize)]
pub struct DraftInfo {
    /// media_id identifying the draft; used for update/delete calls.
    pub media_id: String,
    /// The articles contained in the draft.
    pub content: DraftContent,
    /// Last-update time reported by the API (presumably a Unix timestamp — confirm).
    pub update_time: u64,
}
/// Article list of a draft; `DraftManager` matches drafts on the first entry's title.
#[derive(Debug, Deserialize)]
pub struct DraftContent {
    pub news_item: Vec<Article>,
}
/// Payload of `/cgi-bin/draft/batchget`, as parsed by `DraftManager::list_drafts`.
#[derive(Debug, Deserialize)]
pub struct DraftListResponse {
    /// Total count reported by the API (presumably all drafts, not just this page).
    pub total_count: u32,
    /// Count reported for this page (presumably `item.len()`).
    pub item_count: u32,
    /// The drafts on this page; only this field is returned to callers.
    pub item: Vec<DraftInfo>,
}
/// A permanent image material record.
///
/// It is deserialized from `batchget_material` listings and is also reused as the
/// value stored in `ImageUploader`'s in-memory cache.
#[derive(Debug, Deserialize, Clone)]
pub struct MaterialItem {
    pub media_id: String,
    /// Material name; images uploaded by `ImageUploader` are named
    /// `{blake3 hex digest}.{extension}`, which is how duplicates are recognised.
    pub name: String,
    /// Update time in seconds (the cache fills it with seconds since the Unix epoch).
    pub update_time: u64,
    pub url: String,
}
/// Payload of `/cgi-bin/material/batchget_material` for image materials.
#[derive(Debug, Deserialize)]
pub struct MaterialListResponse {
    /// Total count reported by the API (presumably all image materials).
    pub total_count: u32,
    /// Count reported for this page (presumably `item.len()`).
    pub item_count: u32,
    /// The materials on this page.
    pub item: Vec<MaterialItem>,
}
/// Uploads images (local files or remote URLs) to WeChat as permanent material.
///
/// Every field is an `Arc`, and the manual `Clone` impl only clones those `Arc`s.
/// Clones therefore share the HTTP client, token manager, upload semaphore and
/// material cache.
#[derive(Debug)]
pub struct ImageUploader {
    http_client: Arc<WeChatHttpClient>,
    token_manager: Arc<TokenManager>,
    /// Limits concurrent image uploads to `MAX_CONCURRENT_UPLOADS`.
    semaphore: Arc<Semaphore>,
    /// blake3 hex digest of image bytes -> uploaded material, with a per-entry TTL.
    material_cache: Arc<RwLock<HashMap<String, CachedMaterial>>>,
}
impl ImageUploader {
pub fn new(http_client: Arc<WeChatHttpClient>, token_manager: Arc<TokenManager>) -> Self {
Self {
http_client,
token_manager,
semaphore: Arc::new(Semaphore::new(MAX_CONCURRENT_UPLOADS)),
material_cache: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn upload_images(
&self,
images: Vec<ImageRef>,
base_path: &Path,
) -> Result<Vec<UploadResult>> {
if images.is_empty() {
return Ok(Vec::new());
}
debug!("Uploading {} images concurrently", images.len());
let tasks: Vec<_> = images
.into_iter()
.map(|image_ref| {
let uploader = self.clone();
let base_path = base_path.to_owned();
tokio::spawn(
async move { uploader.upload_single_image(image_ref, &base_path).await },
)
})
.collect();
let results = try_join_all(tasks)
.await
.map_err(|e| WeChatError::Internal {
message: format!("Task join error: {e}"),
})?;
let upload_results: Result<Vec<_>> = results.into_iter().collect();
let uploads = upload_results?;
info!("Successfully uploaded {} images", uploads.len());
Ok(uploads)
}
async fn upload_single_image(
&self,
image_ref: ImageRef,
base_path: &Path,
) -> Result<UploadResult> {
let _permit = self
.semaphore
.acquire()
.await
.map_err(|e| WeChatError::Internal {
message: format!("Semaphore error: {e}"),
})?;
debug!("Processing image: {}", image_ref.original_url);
let image_data = if image_ref.is_local {
let image_path = image_ref.resolve_path(base_path)?;
self.load_local_image(&image_path).await?
} else {
self.download_remote_image(&image_ref.original_url).await?
};
let (media_id, url) = self
.upload_image_as_material(image_data, &image_ref.original_url)
.await?;
info!(
"Successfully uploaded image: {} -> {} (media_id: {})",
image_ref.original_url, url, media_id
);
Ok(UploadResult {
image_ref,
media_id,
url,
})
}
async fn upload_image_as_material(
&self,
image_data: Vec<u8>,
original_path: &str,
) -> Result<(String, String)> {
let hash = blake3::hash(&image_data);
let hash_str = hash.to_hex().to_string();
debug!("Image hash: {hash_str}");
{
let cache = self.material_cache.read().await;
if let Some(cached) = cache.get(&hash_str) {
if !cached.is_expired() {
debug!("Cache hit for hash: {hash_str}");
return Ok((
cached.material.media_id.clone(),
cached.material.url.clone(),
));
} else {
debug!("Cache entry expired for hash: {hash_str}");
}
}
}
debug!("Checking for existing material with hash: {}", hash_str);
if let Some((existing_url, media_id)) = self.find_material_by_hash(&hash_str).await? {
info!("Image already exists with hash {hash_str}, reusing media_id: {media_id}");
{
let mut cache = self.material_cache.write().await;
let material_item = MaterialItem {
media_id: media_id.clone(),
name: hash_str.clone(),
update_time: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs(),
url: existing_url.clone(),
};
cache.insert(hash_str.clone(), CachedMaterial::new(material_item));
debug!("Cached found material for hash: {hash_str}");
}
return Ok((media_id, existing_url));
}
let extension = self.get_image_extension(original_path, &image_data);
let filename = format!("{hash_str}.{extension}");
debug!("Uploading new image as permanent material with filename: {filename}");
let access_token = self.token_manager.get_access_token().await?;
let response = self
.http_client
.upload_material(&access_token, "image", image_data, &filename)
.await?;
let response_text = response.text().await?;
let material = if let Ok(direct_response) =
serde_json::from_str::<MaterialUploadResponse>(&response_text)
{
direct_response
} else {
let upload_response: WeChatResponse<MaterialUploadResponse> =
serde_json::from_str(&response_text)?;
upload_response.into_result()?
};
info!(
"Successfully uploaded new material: {} -> media_id: {} (hash: {})",
original_path, material.media_id, hash_str
);
{
let mut cache = self.material_cache.write().await;
if cache.len() >= MAX_CACHE_SIZE {
let remove_count = MAX_CACHE_SIZE / 10;
let mut to_remove = Vec::with_capacity(remove_count);
for (hash, cached) in cache.iter() {
to_remove.push((hash.clone(), cached.cached_at));
if to_remove.len() >= remove_count {
break;
}
}
to_remove.sort_by_key(|(_, timestamp)| *timestamp);
for (hash, _) in to_remove.into_iter().take(remove_count) {
cache.remove(&hash);
}
debug!("Evicted {} old cache entries", remove_count);
}
let material_item = MaterialItem {
media_id: material.media_id.clone(),
name: hash_str.clone(),
update_time: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs(),
url: material.url.clone(),
};
cache.insert(hash_str.clone(), CachedMaterial::new(material_item));
debug!("Cached material for hash: {hash_str}");
}
Ok((material.media_id, material.url))
}
pub async fn clear_expired_cache(&self) {
let mut cache = self.material_cache.write().await;
let initial_size = cache.len();
cache.retain(|hash, cached| {
let keep = !cached.is_expired();
if !keep {
debug!("Removing expired cache entry: {}", hash);
}
keep
});
let removed = initial_size - cache.len();
if removed > 0 {
info!("Cleared {} expired cache entries", removed);
}
}
pub async fn get_cache_stats(&self) -> (usize, usize) {
let cache = self.material_cache.read().await;
let total = cache.len();
let expired = cache.values().filter(|c| c.is_expired()).count();
(total, expired)
}
async fn load_local_image(&self, path: &Path) -> Result<Vec<u8>> {
let metadata = fs::metadata(path)
.await
.map_err(|e| WeChatError::ImageUpload {
path: path.display().to_string(),
reason: format!("Failed to get file metadata: {e}"),
})?;
let file_size = metadata.len();
if file_size > MAX_IMAGE_SIZE {
return Err(WeChatError::ImageUpload {
path: path.display().to_string(),
reason: format!("File too large: {file_size} bytes (max: {MAX_IMAGE_SIZE} bytes)"),
});
}
debug!(
"Loading local image: {} ({} bytes)",
path.display(),
file_size
);
fs::read(path).await.map_err(|e| WeChatError::ImageUpload {
path: path.display().to_string(),
reason: format!("Failed to read local file: {e}"),
})
}
async fn download_remote_image(&self, url: &str) -> Result<Vec<u8>> {
debug!("Downloading remote image: {url}");
self.http_client
.download_with_limit(url, MAX_DOWNLOAD_SIZE)
.await
.map_err(|e| WeChatError::ImageUpload {
path: url.to_string(),
reason: format!("Failed to download remote image: {e}"),
})
}
fn get_image_extension(&self, url: &str, image_data: &[u8]) -> String {
if let Some(ext) = Path::new(url)
.extension()
.and_then(|e| e.to_str())
.filter(|e| matches!(*e, "jpg" | "jpeg" | "png" | "gif" | "bmp" | "webp"))
{
return ext.to_string();
}
if image_data.len() >= 4 {
match &image_data[0..4] {
[0xFF, 0xD8, 0xFF, _] => return "jpg".to_string(),
[0x89, 0x50, 0x4E, 0x47] => return "png".to_string(),
[0x47, 0x49, 0x46, _] => return "gif".to_string(),
[0x42, 0x4D, _, _] => return "bmp".to_string(),
_ => {}
}
}
if image_data.len() >= 12 && &image_data[0..4] == b"RIFF" && &image_data[8..12] == b"WEBP" {
return "webp".to_string();
}
"jpg".to_string()
}
async fn find_material_by_hash(&self, hash_str: &str) -> Result<Option<(String, String)>> {
debug!("Checking for existing material with hash: {hash_str}");
let access_token = self.token_manager.get_access_token().await?;
let request = serde_json::json!({
"type": "image",
"offset": 0,
"count": 20
});
let response = self
.http_client
.post_json_with_token(
"/cgi-bin/material/batchget_material",
&access_token,
&request,
)
.await
.map_err(|e| {
warn!("Failed to list materials: {e}");
e
});
let response = match response {
Ok(resp) => resp,
Err(_) => return Ok(None),
};
let response_text = response
.text()
.await
.unwrap_or_else(|_| "Unable to read response".to_string());
let materials_result =
serde_json::from_str::<WeChatResponse<MaterialListResponse>>(&response_text);
match materials_result {
Ok(materials_response) => {
if let Ok(material_list) = materials_response.into_result() {
for item in material_list.item {
if item.name.starts_with(hash_str) {
info!(
"Found existing material with hash {}: URL {} (media_id: {})",
hash_str, item.url, item.media_id
);
return Ok(Some((item.url, item.media_id)));
}
}
}
}
Err(e) => {
warn!("Failed to parse material list response: {e}");
}
}
debug!("No existing material found with hash: {hash_str}");
Ok(None)
}
pub async fn upload_cover_material(&self, cover_path: &Path) -> Result<String> {
info!(
"Uploading cover image as permanent material: {}",
cover_path.display()
);
let image_data = self.load_local_image(cover_path).await?;
let (media_id, _url) = self
.upload_image_as_material(image_data, &cover_path.to_string_lossy())
.await?;
info!(
"Successfully uploaded cover image: {} -> media_id: {}",
cover_path.display(),
media_id
);
Ok(media_id)
}
}
impl Clone for ImageUploader {
    /// Returns a new handle that shares every underlying resource.
    ///
    /// Each field is an `Arc`, so this only bumps reference counts: the HTTP client,
    /// token manager, upload semaphore and material cache are shared, not copied.
    fn clone(&self) -> Self {
        let Self {
            http_client,
            token_manager,
            semaphore,
            material_cache,
        } = self;
        Self {
            http_client: Arc::clone(http_client),
            token_manager: Arc::clone(token_manager),
            semaphore: Arc::clone(semaphore),
            material_cache: Arc::clone(material_cache),
        }
    }
}
/// Creates, reads, updates, lists and deletes drafts through the WeChat draft API.
#[derive(Debug)]
pub struct DraftManager {
    http_client: Arc<WeChatHttpClient>,
    token_manager: Arc<TokenManager>,
}
impl DraftManager {
    /// Creates a manager that shares `http_client` and `token_manager`.
    pub fn new(http_client: Arc<WeChatHttpClient>, token_manager: Arc<TokenManager>) -> Self {
        Self {
            http_client,
            token_manager,
        }
    }

    /// Creates a draft from `articles`, or updates an existing draft in place.
    ///
    /// An existing draft is one whose first article's title equals
    /// `articles[0].title`. Only the first page of 20 drafts is searched, and a
    /// failed search is treated as "not found". Returns the media_id of the created
    /// or updated draft.
    ///
    /// # Errors
    ///
    /// A config error if `articles` is empty; otherwise any token, HTTP,
    /// deserialization or API error from the add/update calls.
    pub async fn create_draft(&self, articles: Vec<Article>) -> Result<String> {
        if articles.is_empty() {
            return Err(WeChatError::config_error(
                "At least one article is required",
            ));
        }
        let title = &articles[0].title;
        info!("Processing draft with title: {title}");
        if let Some(existing_media_id) = self.find_draft_by_title(title).await? {
            info!(
                "Found existing draft with title '{title}', updating media_id: {existing_media_id}"
            );
            self.update_draft(&existing_media_id, articles).await?;
            return Ok(existing_media_id);
        }
        info!("No existing draft found, creating new draft");
        let request = DraftRequest { articles };
        let access_token = self.token_manager.get_access_token().await?;
        let response = self
            .http_client
            .post_json_with_token("/cgi-bin/draft/add", &access_token, &request)
            .await?;
        let draft_response: WeChatResponse<DraftResponse> = response.json().await?;
        let draft = draft_response.into_result()?;
        info!(
            "Successfully created new draft with media_id: {}",
            draft.media_id
        );
        Ok(draft.media_id)
    }

    /// Fetches a single draft by media_id.
    ///
    /// NOTE(review): WeChat's `draft/get` documentation describes a body shaped
    /// like `{"news_item": [...]}`. If the endpoint does not also return
    /// `media_id`/`update_time`, deserializing into `DraftInfo` will fail. Confirm
    /// against a live response.
    pub async fn get_draft(&self, media_id: &str) -> Result<DraftInfo> {
        debug!("Getting draft: {media_id}");
        let access_token = self.token_manager.get_access_token().await?;
        let request = serde_json::json!({ "media_id": media_id });
        let response = self
            .http_client
            .post_json_with_token("/cgi-bin/draft/get", &access_token, &request)
            .await?;
        let draft_response: WeChatResponse<DraftInfo> = response.json().await?;
        draft_response.into_result()
    }

    /// Overwrites the articles of draft `media_id`, writing `articles[i]` to position `i`.
    ///
    /// The update endpoint takes a single article plus its `index`, so one request
    /// is sent per article, in order. Previously only `articles[0]` was sent and
    /// the rest were silently dropped. Existing articles beyond `articles.len()`
    /// are left untouched. Presumably the API rejects an `index` the draft does not
    /// have (confirm); such an error is propagated.
    ///
    /// # Errors
    ///
    /// A config error if `articles` is empty; otherwise the first token, HTTP,
    /// deserialization or API error. Articles before the failing one have already
    /// been updated at that point.
    pub async fn update_draft(&self, media_id: &str, articles: Vec<Article>) -> Result<()> {
        if articles.is_empty() {
            return Err(WeChatError::config_error(
                "At least one article is required",
            ));
        }
        info!(
            "Updating draft {} with {} articles",
            media_id,
            articles.len()
        );
        let access_token = self.token_manager.get_access_token().await?;
        for (index, article) in articles.iter().enumerate() {
            let request = serde_json::json!({
                "media_id": media_id,
                "index": index,
                "articles": article
            });
            let response = self
                .http_client
                .post_json_with_token("/cgi-bin/draft/update", &access_token, &request)
                .await?;
            let update_response: WeChatResponse<serde_json::Value> = response.json().await?;
            update_response.into_result()?;
        }
        info!("Successfully updated draft: {media_id}");
        Ok(())
    }

    /// Deletes draft `media_id`.
    pub async fn delete_draft(&self, media_id: &str) -> Result<()> {
        info!("Deleting draft: {media_id}");
        let request = serde_json::json!({ "media_id": media_id });
        let access_token = self.token_manager.get_access_token().await?;
        let response = self
            .http_client
            .post_json_with_token("/cgi-bin/draft/delete", &access_token, &request)
            .await?;
        let delete_response: WeChatResponse<serde_json::Value> = response.json().await?;
        delete_response.into_result()?;
        info!("Successfully deleted draft: {media_id}");
        Ok(())
    }

    /// Lists one page of drafts (`count` drafts starting at `offset`), including content.
    pub async fn list_drafts(&self, offset: u32, count: u32) -> Result<Vec<DraftInfo>> {
        debug!("Listing drafts: offset={offset}, count={count}");
        let request = serde_json::json!({
            "offset": offset,
            "count": count,
            "no_content": 0
        });
        let access_token = self.token_manager.get_access_token().await?;
        let response = self
            .http_client
            .post_json_with_token("/cgi-bin/draft/batchget", &access_token, &request)
            .await?;
        let response_text = response.text().await?;
        let list_response: WeChatResponse<DraftListResponse> =
            serde_json::from_str(&response_text)?;
        let drafts = list_response.into_result()?;
        Ok(drafts.item)
    }

    /// Maps each image's original URL/path to the URL WeChat assigned to it.
    pub fn create_url_mapping(&self, upload_results: &[UploadResult]) -> HashMap<String, String> {
        upload_results
            .iter()
            .map(|result| (result.image_ref.original_url.clone(), result.url.clone()))
            .collect()
    }

    /// Returns the media_id of the first draft (among the first 20 listed) whose
    /// first article has exactly `title`.
    ///
    /// This is best-effort: a failure to list drafts is logged and treated as "no match".
    async fn find_draft_by_title(&self, title: &str) -> Result<Option<String>> {
        debug!("Searching for draft with title: {title}");
        let drafts = match self.list_drafts(0, 20).await {
            Ok(drafts) => drafts,
            Err(e) => {
                warn!("Failed to list drafts: {e}");
                return Ok(None);
            }
        };
        for draft in drafts {
            if let Some(first_article) = draft.content.news_item.first() {
                if first_article.title == title {
                    info!("Found existing draft with matching title");
                    return Ok(Some(draft.media_id));
                }
            }
        }
        debug!("No draft found with title: {title}");
        Ok(None)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::TokenManager;
    use std::sync::Arc;

    /// Builds an HTTP client plus a token manager with dummy credentials.
    fn test_clients() -> (Arc<WeChatHttpClient>, Arc<TokenManager>) {
        let http_client = Arc::new(WeChatHttpClient::new().unwrap());
        let token_manager = Arc::new(TokenManager::new(
            "test_app_id",
            "test_secret",
            Arc::clone(&http_client),
        ));
        (http_client, token_manager)
    }

    #[tokio::test]
    async fn test_article_creation() {
        let Article {
            title,
            author,
            content,
            show_cover_pic,
            need_open_comment,
            ..
        } = Article::new(
            "Test Title".to_string(),
            "Test Author".to_string(),
            "<h1>Test Content</h1>".to_string(),
        );
        assert_eq!(title, "Test Title");
        assert_eq!(author, "Test Author");
        assert_eq!(content, "<h1>Test Content</h1>");
        assert_eq!(show_cover_pic, 1);
        assert_eq!(need_open_comment, 0);
    }

    #[tokio::test]
    async fn test_article_builder_methods() {
        let base = Article::new(
            "Title".to_string(),
            "Author".to_string(),
            "Content".to_string(),
        );
        let article = base
            .with_digest("Test digest".to_string())
            .with_cover_image("cover_media_id".to_string())
            .with_show_cover(false)
            .with_comments(true, true)
            .with_source_url("https://example.com".to_string());
        assert_eq!(article.digest, "Test digest");
        assert_eq!(article.thumb_media_id.as_deref(), Some("cover_media_id"));
        assert_eq!(article.show_cover_pic, 0);
        assert_eq!(article.need_open_comment, 1);
        assert_eq!(article.only_fans_can_comment, 1);
        assert_eq!(
            article.content_source_url.as_deref(),
            Some("https://example.com")
        );
    }

    #[tokio::test]
    async fn test_image_uploader_creation() {
        let (http_client, token_manager) = test_clients();
        let uploader = ImageUploader::new(http_client, token_manager);
        assert_eq!(
            uploader.semaphore.available_permits(),
            MAX_CONCURRENT_UPLOADS
        );
    }

    #[tokio::test]
    async fn test_draft_manager_creation() {
        let (http_client, token_manager) = test_clients();
        let _manager = DraftManager::new(http_client, token_manager);
    }

    #[test]
    fn test_image_extension_detection() {
        let (http_client, token_manager) = test_clients();
        let uploader = ImageUploader::new(http_client, token_manager);
        let cases: [(&str, &[u8], &str); 5] = [
            ("test.jpg", &[], "jpg"),
            ("test.png", &[], "png"),
            ("test.webp", &[], "webp"),
            ("noext", &[0xFF, 0xD8, 0xFF, 0xE0], "jpg"),
            ("noext", &[0x89, 0x50, 0x4E, 0x47], "png"),
        ];
        for (name, data, expected) in cases {
            assert_eq!(
                uploader.get_image_extension(name, data),
                expected,
                "extension for {name}"
            );
        }
    }

    #[test]
    fn test_url_mapping_creation() {
        let (http_client, token_manager) = test_clients();
        let manager = DraftManager::new(http_client, token_manager);
        let upload_result = UploadResult {
            image_ref: ImageRef::new("Alt".to_string(), "./test.jpg".to_string(), (0, 10)),
            media_id: "media123".to_string(),
            url: "https://wechat.com/image123".to_string(),
        };
        let mapping = manager.create_url_mapping(&[upload_result]);
        assert_eq!(
            mapping.get("./test.jpg").map(String::as_str),
            Some("https://wechat.com/image123")
        );
    }
}