use crate::api::{ApiClient, ApiError};
use crate::commands::guards::{check_write_allowed, confirm_destructive_with_hint};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;
use std::path::Path;
/// Response body of `files.getUploadURLExternal`.
///
/// When `ok` is true, `file_upload` reads `upload_url` and `file_id`; otherwise `error`
/// is surfaced in the returned error message.
#[derive(Debug, Deserialize)]
struct GetUploadUrlResponse {
/// Whether the Slack call succeeded.
ok: bool,
/// URL the raw file bytes are POSTed to (the bearer token is not sent there).
upload_url: Option<String>,
/// File ID later passed to `files.completeUploadExternal`.
file_id: Option<String>,
/// Error string reported by Slack when `ok` is false.
error: Option<String>,
}
/// Response body of `files.completeUploadExternal`.
///
/// Deserialized from Slack's reply and, on success, re-serialized as the return value of
/// `file_upload`. `None` fields are omitted from that serialized output.
#[derive(Debug, Deserialize, Serialize)]
struct CompleteUploadResponse {
/// Whether the Slack call succeeded.
ok: bool,
/// File objects returned by Slack, kept as raw JSON values.
#[serde(skip_serializing_if = "Option::is_none")]
files: Option<Vec<serde_json::Value>>,
/// Error string reported by Slack when `ok` is false.
#[serde(skip_serializing_if = "Option::is_none")]
error: Option<String>,
}
/// Uploads a local file to Slack using the external upload flow.
///
/// The upload happens in three steps:
/// 1. `files.getUploadURLExternal` reserves a file ID and returns an upload URL.
/// 2. The raw file bytes are POSTed to that upload URL (no bearer token is sent there).
/// 3. `files.completeUploadExternal` finalizes the upload, optionally sharing the file to
///    channels and attaching an initial comment.
///
/// # Arguments
/// * `client` - API client providing the base URL and bearer token.
/// * `file_path` - Path of the local file to upload.
/// * `channels` - Optional channel ID, or comma-separated list of channel IDs, to share the
///   file to. Whitespace around IDs is ignored. A single ID is sent as `channel_id`;
///   several IDs are sent as the comma-separated `channels` argument. With no IDs, the
///   file is not shared.
/// * `title` - Optional file title; defaults to the file's name.
/// * `comment` - Optional initial comment posted with the file.
/// * `yes` / `non_interactive` - Forwarded to the destructive-operation confirmation guard.
///
/// # Returns
/// The `files.completeUploadExternal` response (`ok` plus `files`) as JSON.
///
/// # Errors
/// Errors from `check_write_allowed` / `confirm_destructive_with_hint` are propagated
/// unchanged. Every other failure is an `ApiError::SlackError`: a missing or unreadable
/// file, a missing token, transport or parse errors, a non-success upload status, or a
/// Slack `ok: false` reply.
pub async fn file_upload(
    client: &ApiClient,
    file_path: String,
    channels: Option<String>,
    title: Option<String>,
    comment: Option<String>,
    yes: bool,
    non_interactive: bool,
) -> Result<serde_json::Value, ApiError> {
    check_write_allowed()?;
    let hint = format!("Example: slack-rs file upload {} --yes", file_path);
    confirm_destructive_with_hint(yes, "upload this file", non_interactive, Some(&hint))?;

    let path = Path::new(&file_path);
    // Read once and classify the error instead of probing with `exists()` first. That
    // avoids a check-then-use race and keeps permission errors from being reported as
    // "not found". NOTE: this is a blocking read inside an async fn.
    let file_bytes = std::fs::read(path).map_err(|e| {
        if e.kind() == std::io::ErrorKind::NotFound {
            ApiError::SlackError(format!("File not found: {}", file_path))
        } else {
            ApiError::SlackError(format!("Failed to read file {}: {}", file_path, e))
        }
    })?;
    let file_name = path
        .file_name()
        .and_then(|n| n.to_str())
        .unwrap_or("file")
        .to_string();
    let file_length_str = file_bytes.len().to_string();

    let token = client
        .token
        .as_ref()
        .ok_or_else(|| ApiError::SlackError("No token configured".to_string()))?;
    let http_client = Client::new();

    // Step 1: reserve a file ID and obtain the upload URL.
    let url = format!("{}/files.getUploadURLExternal", client.base_url());
    let form_params = vec![
        ("filename", file_name.as_str()),
        ("length", file_length_str.as_str()),
    ];
    let get_url_response = http_client
        .post(&url)
        .bearer_auth(token)
        .form(&form_params)
        .send()
        .await
        .map_err(|e| ApiError::SlackError(format!("Failed to get upload URL: {}", e)))?;
    let get_url_result: GetUploadUrlResponse = get_url_response
        .json()
        .await
        .map_err(|e| ApiError::SlackError(format!("Failed to parse upload URL response: {}", e)))?;
    if !get_url_result.ok {
        return Err(ApiError::SlackError(format!(
            "files.getUploadURLExternal failed: {}",
            get_url_result
                .error
                .unwrap_or_else(|| "Unknown error".to_string())
        )));
    }
    let upload_url = get_url_result
        .upload_url
        .ok_or_else(|| ApiError::SlackError("No upload_url in response".to_string()))?;
    let file_id = get_url_result
        .file_id
        .ok_or_else(|| ApiError::SlackError("No file_id in response".to_string()))?;

    // Step 2: send the raw bytes to the upload URL (no bearer token here).
    let upload_response = http_client
        .post(&upload_url)
        .header("Content-Type", "application/octet-stream")
        .body(file_bytes)
        .send()
        .await
        .map_err(|e| ApiError::SlackError(format!("Failed to upload file: {}", e)))?;
    if !upload_response.status().is_success() {
        return Err(ApiError::SlackError(format!(
            "File upload failed with status: {}",
            upload_response.status()
        )));
    }

    // Step 3: finalize the upload and optionally share it.
    let mut complete_params = HashMap::new();
    complete_params.insert(
        "files".to_string(),
        json!([{
            "id": file_id,
            "title": title.unwrap_or_else(|| file_name.clone())
        }]),
    );
    if let Some(ch) = channels {
        let channel_ids: Vec<&str> = ch
            .split(',')
            .map(str::trim)
            .filter(|id| !id.is_empty())
            .collect();
        match channel_ids.as_slice() {
            [] => {}
            // `channel_id` takes exactly one channel.
            [single] => {
                complete_params.insert("channel_id".to_string(), json!(*single));
            }
            // Multiple channels must go through the comma-separated `channels` argument.
            many => {
                complete_params.insert("channels".to_string(), json!(many.join(",")));
            }
        }
    }
    if let Some(cmt) = comment {
        complete_params.insert("initial_comment".to_string(), json!(cmt));
    }
    let complete_url = format!("{}/files.completeUploadExternal", client.base_url());
    let complete_response = http_client
        .post(&complete_url)
        .bearer_auth(token)
        .json(&complete_params)
        .send()
        .await
        .map_err(|e| ApiError::SlackError(format!("Failed to complete upload: {}", e)))?;
    let complete_result: CompleteUploadResponse = complete_response
        .json()
        .await
        .map_err(|e| ApiError::SlackError(format!("Failed to parse complete response: {}", e)))?;
    if !complete_result.ok {
        return Err(ApiError::SlackError(format!(
            "files.completeUploadExternal failed: {}",
            complete_result
                .error
                .unwrap_or_else(|| "Unknown error".to_string())
        )));
    }
    serde_json::to_value(complete_result)
        .map_err(|e| ApiError::SlackError(format!("Failed to serialize result: {}", e)))
}
/// Response body of `files.info`, as consumed by `file_download`.
///
/// On success `file` carries the file metadata; otherwise `error` is surfaced in the
/// returned error message.
#[derive(Debug, Deserialize)]
struct FilesInfoResponse {
/// Whether the Slack call succeeded.
ok: bool,
/// File metadata (only the fields needed for downloading are modeled).
file: Option<FileInfo>,
/// Error string reported by Slack when `ok` is false.
error: Option<String>,
}
/// Subset of a Slack file object used to locate and name a download.
///
/// All fields default to `None` when absent from the JSON.
#[derive(Debug, Deserialize)]
struct FileInfo {
/// File name; used as the local file name hint.
#[serde(default)]
name: Option<String>,
/// Preferred download URL.
#[serde(default)]
url_private_download: Option<String>,
/// Fallback URL used when `url_private_download` is missing.
#[serde(default)]
url_private: Option<String>,
}
/// Downloads a Slack file to disk or stdout.
///
/// The download URL is either resolved from `file_id` via `files.info` or taken verbatim
/// from `url`; `file_id` wins if both are given. With `files.info`, `url_private_download`
/// is preferred over `url_private`.
///
/// Redirects are followed manually, up to 10 hops, so the bearer token is re-attached on
/// every hop. A redirect that would downgrade from HTTPS to plain HTTP is refused, because
/// following it would send the token in clear text. NOTE(review): if the caller-supplied
/// `url` is itself `http://`, the token is still sent on the first request.
///
/// # Arguments
/// * `client` - API client providing the base URL and bearer token.
/// * `file_id` - Slack file ID to look up with `files.info`.
/// * `url` - Direct download URL, used when `file_id` is `None`.
/// * `out` - Output destination:
///   * `"-"` writes to stdout.
///   * An existing directory writes into it, with the name taken from Slack metadata or
///     the URL's last path segment.
///   * Any other path is written to directly.
///   * `None` writes into the current directory.
///
/// # Returns
/// A JSON object with these fields:
/// * `ok` - always `true`.
/// * `output` - the path written, or `"-"`.
/// * `size` - the number of bytes downloaded.
/// * `url` - the resolved download URL, before redirects.
///
/// # Errors
/// `ApiError::SlackError` in any of these cases:
/// * no token is configured, or neither `file_id` nor `url` is given;
/// * `files.info` fails;
/// * a redirect is excessive, malformed, or downgrading;
/// * the server returns HTML or a non-success status;
/// * the output cannot be written.
pub async fn file_download(
    client: &ApiClient,
    file_id: Option<String>,
    url: Option<String>,
    out: Option<String>,
) -> Result<serde_json::Value, ApiError> {
    const MAX_REDIRECTS: u8 = 10;

    let token = client
        .token
        .as_ref()
        .ok_or_else(|| ApiError::SlackError("No token configured".to_string()))?;

    let (download_url, filename_hint) = if let Some(fid) = file_id {
        let http_client = Client::new();
        let info_url = format!("{}/files.info", client.base_url());
        let form_params = vec![("file".to_string(), fid.clone())];
        let info_response = http_client
            .post(&info_url)
            .bearer_auth(token)
            .form(&form_params)
            .send()
            .await
            .map_err(|e| ApiError::SlackError(format!("Failed to call files.info: {}", e)))?;
        let info_result: FilesInfoResponse = info_response.json().await.map_err(|e| {
            ApiError::SlackError(format!("Failed to parse files.info response: {}", e))
        })?;
        if !info_result.ok {
            return Err(ApiError::SlackError(format!(
                "files.info failed: {}",
                info_result
                    .error
                    .unwrap_or_else(|| "Unknown error".to_string())
            )));
        }
        let file = info_result.file.ok_or_else(|| {
            ApiError::SlackError("No file information in files.info response".to_string())
        })?;
        let resolved_url = file
            .url_private_download
            .or(file.url_private)
            .ok_or_else(|| {
                ApiError::SlackError("No download URL found in file info".to_string())
            })?;
        let name = file.name.unwrap_or_else(|| format!("file-{}", fid));
        (resolved_url, name)
    } else if let Some(direct_url) = url {
        let name = download_filename_from_url(&direct_url);
        (direct_url, name)
    } else {
        return Err(ApiError::SlackError(
            "Either file_id or url must be provided".to_string(),
        ));
    };

    // Build the client once. Automatic redirects are disabled because reqwest's default
    // policy drops the Authorization header on cross-host hops; they are handled by hand.
    let no_redirect_client = reqwest::ClientBuilder::new()
        .redirect(reqwest::redirect::Policy::none())
        .build()
        .map_err(|e| ApiError::SlackError(format!("Failed to build HTTP client: {}", e)))?;

    let mut current_url = download_url.clone();
    let mut redirect_count: u8 = 0;
    let download_response = loop {
        let response = no_redirect_client
            .get(&current_url)
            .bearer_auth(token)
            .send()
            .await
            .map_err(|e| ApiError::SlackError(format!("Failed to download file: {}", e)))?;
        let status = response.status();
        if !status.is_redirection() {
            break response;
        }
        if redirect_count >= MAX_REDIRECTS {
            return Err(ApiError::SlackError(format!(
                "Too many redirects (max {})",
                MAX_REDIRECTS
            )));
        }
        let location = response
            .headers()
            .get("location")
            .and_then(|h| h.to_str().ok())
            .ok_or_else(|| {
                ApiError::SlackError(format!(
                    "Redirect response {} missing Location header",
                    status
                ))
            })?;
        let base = reqwest::Url::parse(&current_url).map_err(|e| {
            ApiError::SlackError(format!("Failed to parse URL {}: {}", current_url, e))
        })?;
        // `Url::join` resolves relative locations against `base` and passes absolute
        // locations through unchanged, so both cases share one code path.
        let next = base.join(location).map_err(|e| {
            ApiError::SlackError(format!(
                "Failed to join URLs {} + {}: {}",
                current_url, location, e
            ))
        })?;
        // The token is re-sent on every hop, so never follow a downgrade to plaintext.
        if base.scheme() == "https" && next.scheme() != "https" {
            return Err(ApiError::SlackError(format!(
                "Refusing to follow redirect from {} to non-HTTPS URL {}",
                current_url, next
            )));
        }
        current_url = next.to_string();
        redirect_count += 1;
    };

    // Slack answers with an HTML page (e.g. a login page) instead of file bytes when the
    // URL or auth is wrong; report that explicitly instead of saving the HTML.
    let is_html = download_response
        .headers()
        .get("content-type")
        .and_then(|ct| ct.to_str().ok())
        .map(|ct_str| ct_str.contains("text/html"))
        .unwrap_or(false);
    if is_html {
        let status = download_response.status();
        let body_bytes = download_response
            .bytes()
            .await
            .map_err(|e| ApiError::SlackError(format!("Failed to read HTML response: {}", e)))?;
        let body_str = String::from_utf8_lossy(&body_bytes);
        let snippet = truncate_safely(&body_str, 200);
        return Err(ApiError::SlackError(format!(
            "Download returned HTML instead of file (status: {}). Possible causes:\n\
             - Wrong URL: Make sure to use url_private_download, not permalink\n\
             - Missing authentication: Token may lack required scopes\n\
             - Invalid or expired file\n\
             \n\
             Response snippet:\n{}",
            status, snippet
        )));
    }
    if !download_response.status().is_success() {
        return Err(ApiError::SlackError(format!(
            "Download failed with status: {}",
            download_response.status()
        )));
    }
    let bytes = download_response
        .bytes()
        .await
        .map_err(|e| ApiError::SlackError(format!("Failed to read response body: {}", e)))?;

    let output_path = write_download_output(out.as_deref(), &filename_hint, &bytes)?;

    Ok(json!({
        "ok": true,
        "output": output_path,
        "size": bytes.len(),
        "url": download_url
    }))
}

/// Derives a file-name hint from the last non-empty path segment of `url`.
///
/// Query strings and fragments are ignored. Falls back to `"downloaded-file"` when the URL
/// cannot be parsed or has no usable path segment.
fn download_filename_from_url(url: &str) -> String {
    let last_segment = match reqwest::Url::parse(url) {
        Ok(parsed) => parsed
            .path_segments()
            .and_then(|segments| segments.rev().find(|s| !s.is_empty()))
            .map(str::to_string),
        Err(_) => None,
    };
    last_segment.unwrap_or_else(|| "downloaded-file".to_string())
}

/// Writes downloaded `bytes` to the destination selected by `out`.
///
/// Returns the destination as reported in the result JSON (`"-"` for stdout). See
/// `file_download` for how `out` is interpreted.
fn write_download_output(
    out: Option<&str>,
    filename_hint: &str,
    bytes: &[u8],
) -> Result<String, ApiError> {
    use std::io::Write;

    let target_path = match out {
        Some("-") => {
            let stdout = std::io::stdout();
            let mut handle = stdout.lock();
            // Flush explicitly so a failed write is reported here rather than lost at exit.
            handle
                .write_all(bytes)
                .and_then(|()| handle.flush())
                .map_err(|e| ApiError::SlackError(format!("Failed to write to stdout: {}", e)))?;
            return Ok("-".to_string());
        }
        Some(path) if Path::new(path).is_dir() => {
            Path::new(path).join(sanitize_filename(filename_hint))
        }
        Some(path) => Path::new(path).to_path_buf(),
        None => Path::new(".").join(sanitize_filename(filename_hint)),
    };
    std::fs::write(&target_path, bytes).map_err(|e| {
        ApiError::SlackError(format!(
            "Failed to write file to {}: {}",
            target_path.display(),
            e
        ))
    })?;
    Ok(target_path.display().to_string())
}
/// Makes `name` safe to use as a single file-name component of a local path.
///
/// The following characters are replaced with `_`:
/// * path separators;
/// * characters reserved on Windows (`: * ? " < > |`);
/// * NUL and other control characters, such as newlines.
///
/// Results that would be empty or name a directory reference (`"."`, `".."`) become
/// `"file"`. Joining the result onto a directory therefore never targets that directory
/// itself or its parent.
fn sanitize_filename(name: &str) -> String {
    const INVALID_CHARS: [char; 10] = ['/', '\\', ':', '*', '?', '"', '<', '>', '|', '\0'];
    let sanitized: String = name
        .chars()
        .map(|c| {
            if INVALID_CHARS.contains(&c) || c.is_control() {
                '_'
            } else {
                c
            }
        })
        .collect();
    if matches!(sanitized.as_str(), "" | "." | "..") {
        "file".to_string()
    } else {
        sanitized
    }
}
/// Truncates `s` to at most `max_len` characters (Unicode scalar values).
///
/// `"..."` is appended only when something was actually cut off. The cut always lands on
/// a `char` boundary, so multi-byte text never panics or splits a code point. The limit
/// is counted in characters, not bytes.
fn truncate_safely(s: &str, max_len: usize) -> String {
    // `nth(max_len)` yields the byte offset of the first char that does not fit; `None`
    // means the whole string already fits within the limit.
    match s.char_indices().nth(max_len) {
        Some((cut, _)) => format!("{}...", &s[..cut]),
        None => s.to_string(),
    }
}
// Unit tests for the file upload/download commands and their string helpers.
//
// Tests that read or write SLACKCLI_ALLOW_WRITE share the `write_guard` serial key, so
// they never run concurrently with each other.
#[cfg(test)]
mod tests {
use super::*;
use serial_test::serial;
// With writes explicitly disabled, upload must be rejected by the write guard.
// NOTE(review): if the assertion fails, the env var is not reset, because removal
// happens after the asserts.
#[tokio::test]
#[serial(write_guard)]
async fn test_file_upload_write_not_allowed() {
std::env::set_var("SLACKCLI_ALLOW_WRITE", "false");
let client = ApiClient::with_token("test_token".to_string());
let result = file_upload(
&client,
"/tmp/test.txt".to_string(),
None,
None,
None,
true,
false,
)
.await;
assert!(result.is_err());
assert!(matches!(result.unwrap_err(), ApiError::WriteNotAllowed));
std::env::remove_var("SLACKCLI_ALLOW_WRITE");
}
// A missing local file is reported before any network call is made.
// Assumes writes are allowed when SLACKCLI_ALLOW_WRITE is unset.
// NOTE(review): confirm that default in `check_write_allowed`.
#[tokio::test]
#[serial(write_guard)]
async fn test_file_upload_nonexistent_file() {
std::env::remove_var("SLACKCLI_ALLOW_WRITE");
let client = ApiClient::with_token("test_token".to_string());
let result = file_upload(
&client,
"/nonexistent/file.txt".to_string(),
None,
None,
None,
true,
false,
)
.await;
assert!(result.is_err());
if let Err(ApiError::SlackError(msg)) = result {
assert!(msg.contains("File not found"));
} else {
panic!("Expected SlackError with 'File not found'");
}
}
// Path separators and reserved characters are replaced; empty input maps to "file".
#[test]
fn test_sanitize_filename() {
assert_eq!(sanitize_filename("test.txt"), "test.txt");
assert_eq!(sanitize_filename("test/file.txt"), "test_file.txt");
assert_eq!(sanitize_filename("test:file.txt"), "test_file.txt");
assert_eq!(sanitize_filename("test*file?.txt"), "test_file_.txt");
assert_eq!(sanitize_filename(""), "file");
}
// Truncation keeps at most `max_len` chars, appends "..." when cut, and does not split
// multi-byte characters.
#[test]
fn test_truncate_safely() {
assert_eq!(truncate_safely("short", 100), "short");
assert_eq!(truncate_safely("exact", 5), "exact");
let long_str = "This is a very long string that needs to be truncated";
let truncated = truncate_safely(long_str, 20);
assert_eq!(truncated, "This is a very long ...");
assert!(truncated.len() <= 23);
assert_eq!(truncate_safely("", 10), "");
let unicode = "日本語テキスト";
let result = truncate_safely(unicode, 3);
assert!(result.starts_with("日本語"));
}
// Despite its name, this checks that download is NOT gated by the write guard: with
// writes disabled, the call may still fail, but not with WriteNotAllowed.
// NOTE(review): `ApiClient::with_token`'s base URL is not shown here. If it targets the
// real Slack API, this test performs network I/O and is not hermetic.
#[tokio::test]
#[serial(write_guard)]
async fn test_file_download_write_allowed() {
std::env::set_var("SLACKCLI_ALLOW_WRITE", "false");
let client = ApiClient::with_token("test_token".to_string());
let result = file_download(
&client,
Some("F123456".to_string()),
None,
Some("/tmp/test_download.txt".to_string()),
)
.await;
std::env::remove_var("SLACKCLI_ALLOW_WRITE");
assert!(result.is_err());
if let Err(e) = result {
assert!(!matches!(e, ApiError::WriteNotAllowed));
}
}
}