use super::{Label, Suggest};
use crate::{Result, client::Client, constants, error::Error};
use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::debug;
/// Content scene attached to a security-check request.
///
/// Every variant has an explicit integer discriminant (1 through 4);
/// `Client::msg_sec_check` casts the variant to that integer to fill the
/// request's `scene` field.
///
/// NOTE(review): the derived `Serialize` presumably emits the variant *name*
/// rather than the integer — confirm before serializing this type directly.
#[derive(Debug, Serialize, Clone, Copy, PartialEq)]
pub enum Scene {
/// Scene value 1; `description` renders it as "资料" (profile).
Profile = 1,
/// Scene value 2; `description` renders it as "评论" (comment).
Comment = 2,
/// Scene value 3; `description` renders it as "论坛" (forum).
Forum = 3,
/// Scene value 4; `description` renders it as "社交日志" (social log).
SocialLog = 4,
}
/// Parameters of a `Client::msg_sec_check` request.
///
/// Create one with `Args::new` (required fields only) or through
/// `ArgsBuilder`. `Args::validate` checks the content-length limit and the
/// signature/scene rule before a request is sent.
#[derive(Debug, Serialize, Clone)]
pub struct Args {
/// Text submitted for checking; sent as the `content` field.
pub content: String,
/// Value sent as the `version` field; `Args::new` and the builder default it to 2.
pub version: u32,
/// Scene the content belongs to; sent as the `scene` field.
pub scene: Scene,
/// Value sent as the `openid` field.
/// NOTE(review): presumably the openid of the user who produced the content — confirm.
pub openid: String,
/// Optional title; left out of the serialized form and of the request when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub title: Option<String>,
/// Optional nickname; left out of the serialized form and of the request when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub nickname: Option<String>,
/// Optional signature; left out when `None`. Validation rejects it unless
/// `scene` is `Scene::Profile`.
#[serde(skip_serializing_if = "Option::is_none")]
pub signature: Option<String>,
}
/// Incremental builder for `Args`.
///
/// All fields start as `None` (via the derived `Default`). `build` requires
/// `content`, `scene` and `openid` to have been set and substitutes 2 for a
/// missing `version`.
#[derive(Debug, Default)]
pub struct ArgsBuilder {
// Required by `build`.
content: Option<String>,
// Optional; `build` falls back to 2.
version: Option<u32>,
// Required by `build`.
scene: Option<Scene>,
// Required by `build`.
openid: Option<String>,
// Optional; copied into `Args` unchanged.
title: Option<String>,
// Optional; copied into `Args` unchanged.
nickname: Option<String>,
// Optional; `build` rejects it unless the scene is `Scene::Profile`.
signature: Option<String>,
}
impl ArgsBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn content(mut self, content: impl Into<String>) -> Self {
self.content = Some(content.into());
self
}
pub fn version(mut self, version: u32) -> Self {
self.version = Some(version);
self
}
pub fn scene(mut self, scene: Scene) -> Self {
self.scene = Some(scene);
self
}
pub fn openid(mut self, openid: impl Into<String>) -> Self {
self.openid = Some(openid.into());
self
}
pub fn title(mut self, title: impl Into<String>) -> Self {
self.title = Some(title.into());
self
}
pub fn nickname(mut self, nickname: impl Into<String>) -> Self {
self.nickname = Some(nickname.into());
self
}
pub fn signature(mut self, signature: impl Into<String>) -> Self {
self.signature = Some(signature.into());
self
}
pub fn build(self) -> Result<Args> {
let content = self
.content
.ok_or(Error::InvalidParameter("content 是必填参数".to_string()))?;
let version = self.version.unwrap_or(2); let scene = self
.scene
.ok_or(Error::InvalidParameter("scene 是必填参数".to_string()))?;
let openid = self
.openid
.ok_or(Error::InvalidParameter("openid 是必填参数".to_string()))?;
if content.len() > 2500 {
return Err(Error::InvalidParameter(
"content 长度不能超过2500字".to_string(),
));
}
if self.signature.is_some() && scene != Scene::Profile {
return Err(Error::InvalidParameter(
"signature 仅在资料场景(scene=1)下有效".to_string(),
));
}
Ok(Args {
content,
version,
scene,
openid,
title: self.title,
nickname: self.nickname,
signature: self.signature,
})
}
}
impl Args {
    /// Returns an empty `ArgsBuilder`.
    pub fn builder() -> ArgsBuilder {
        ArgsBuilder::new()
    }

    /// Creates arguments from the required fields only.
    ///
    /// `version` is set to `2` and every optional field to `None`. No
    /// validation happens here; call `validate` to check the result.
    pub fn new(content: impl Into<String>, scene: Scene, openid: impl Into<String>) -> Self {
        Self {
            content: content.into(),
            version: 2,
            scene,
            openid: openid.into(),
            title: None,
            nickname: None,
            signature: None,
        }
    }

    /// Returns `true` when the scene is `Scene::Profile`.
    pub fn is_profile_scene(&self) -> bool {
        self.scene == Scene::Profile
    }

    /// Returns the length of `content` in UTF-8 **bytes**.
    ///
    /// This is not the figure `validate` compares against the limit; that
    /// check counts characters.
    pub fn content_length(&self) -> usize {
        self.content.len()
    }

    /// Checks the arguments against the request constraints.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidParameter` when `content` is longer than 2500
    /// characters, or when `signature` is set for a scene other than
    /// `Scene::Profile`.
    pub fn validate(&self) -> Result<()> {
        // The limit is expressed in characters ("字"); counting bytes with
        // `len()` would reject multi-byte text far below the real limit.
        if self.content.chars().count() > 2500 {
            return Err(Error::InvalidParameter(
                "content 长度不能超过2500字".to_string(),
            ));
        }

        if self.signature.is_some() && !self.is_profile_scene() {
            return Err(Error::InvalidParameter(
                "signature 仅在资料场景(scene=1)下有效".to_string(),
            ));
        }

        Ok(())
    }
}
impl Scene {
    /// Looks up the scene whose integer discriminant equals `value`.
    ///
    /// Returns `None` when no variant carries that value.
    pub fn from_value(value: u32) -> Option<Self> {
        const SCENES: [Scene; 4] = [
            Scene::Profile,
            Scene::Comment,
            Scene::Forum,
            Scene::SocialLog,
        ];

        SCENES
            .iter()
            .copied()
            .find(|candidate| *candidate as u32 == value)
    }

    /// Returns the fixed Chinese display name of the scene.
    pub fn description(&self) -> &'static str {
        match *self {
            Self::Profile => "资料",
            Self::Comment => "评论",
            Self::Forum => "论坛",
            Self::SocialLog => "社交日志",
        }
    }
}
/// One element of the `detail` array in a security-check response.
///
/// Only `strategy` and `errcode` are mandatory when deserializing; the
/// remaining fields are optional and are skipped on serialization when `None`.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct DetailResult {
/// Name of the strategy that produced this entry (the test fixtures use
/// `"content_model"`).
pub strategy: String,
/// Per-entry error code; `MsgSecCheckResult::get_valid_details` keeps only
/// entries where this is 0.
pub errcode: i32,
/// Suggestion attached to this entry, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub suggest: Option<Suggest>,
/// Label attached to this entry, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub label: Option<Label>,
/// Keyword attached to this entry, if any.
/// NOTE(review): presumably the matched sensitive keyword — confirm against the API docs.
#[serde(skip_serializing_if = "Option::is_none")]
pub keyword: Option<String>,
/// Numeric score attached to this entry, if any.
/// NOTE(review): presumably a confidence value (fixtures use 90.5 / 95.0) — confirm range and meaning.
#[serde(skip_serializing_if = "Option::is_none")]
pub prob: Option<f64>,
}
/// The `result` object of a security-check response: a single suggestion and
/// label, both mandatory when deserializing.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ComprehensiveResult {
/// Suggestion exposed through `MsgSecCheckResult::get_suggest`.
pub suggest: Suggest,
/// Label exposed through `MsgSecCheckResult::get_label`.
pub label: Label,
}
/// Deserialized body of a security-check response.
///
/// `errcode` and `errmsg` are mandatory; `detail`, `result` and `trace_id`
/// are optional and skipped on serialization when `None`.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct MsgSecCheckResult {
/// Top-level error code; `is_success` treats 0 as success.
pub errcode: i32,
/// Message accompanying `errcode`.
pub errmsg: String,
/// Per-strategy detail entries, if present.
#[serde(skip_serializing_if = "Option::is_none")]
pub detail: Option<Vec<DetailResult>>,
/// Overall suggestion and label, if present.
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<ComprehensiveResult>,
/// Trace identifier returned with the response, if present.
#[serde(skip_serializing_if = "Option::is_none")]
pub trace_id: Option<String>,
}
impl MsgSecCheckResult {
    /// Returns `true` when the top-level `errcode` is 0.
    pub fn is_success(&self) -> bool {
        self.errcode == 0
    }

    /// Borrows the overall suggestion, or `None` when the response has no `result`.
    pub fn get_suggest(&self) -> Option<&Suggest> {
        let overall = self.result.as_ref()?;
        Some(&overall.suggest)
    }

    /// Borrows the overall label, or `None` when the response has no `result`.
    pub fn get_label(&self) -> Option<&Label> {
        let overall = self.result.as_ref()?;
        Some(&overall.label)
    }

    /// Returns `true` only when an overall suggestion exists and reports a pass.
    pub fn is_pass(&self) -> bool {
        self.get_suggest().is_some_and(|suggest| suggest.is_pass())
    }

    /// Returns `true` only when an overall suggestion exists and reports risky content.
    pub fn is_risky(&self) -> bool {
        self.get_suggest().is_some_and(|suggest| suggest.is_risky())
    }

    /// Returns `true` only when an overall suggestion exists and asks for review.
    pub fn needs_review(&self) -> bool {
        self.get_suggest()
            .is_some_and(|suggest| suggest.needs_review())
    }

    /// Collects references to the detail entries whose `errcode` is 0.
    ///
    /// Yields an empty vector when the response carries no `detail` array.
    pub fn get_valid_details(&self) -> Vec<&DetailResult> {
        // `Option::iter` yields the inner Vec zero or one time; flattening
        // then walks its elements.
        self.detail
            .iter()
            .flatten()
            .filter(|entry| entry.errcode == 0)
            .collect()
    }
}
impl Client {
pub async fn msg_sec_check(&self, args: &Args) -> Result<MsgSecCheckResult> {
debug!("msg_sec_check args: {:?}", &args);
args.validate()?;
let access_token = self.access_token().await?;
let mut query = HashMap::new();
let mut body = HashMap::new();
let version = args.version.to_string();
let scene = (args.scene as u32).to_string();
query.insert("access_token", &access_token);
body.insert("content", &args.content);
body.insert("version", &version);
body.insert("scene", &scene);
body.insert("openid", &args.openid);
if let Some(title) = &args.title {
body.insert("title", title);
}
if let Some(nickname) = &args.nickname {
body.insert("nickname", nickname);
}
if let Some(signature) = &args.signature {
body.insert("signature", signature);
}
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
let response = self
.request()
.post(constants::MSG_SEC_CHECK_END_POINT)
.headers(headers)
.query(&query)
.json(&body)
.send()
.await?;
debug!("msg_sec_check response: {:#?}", response);
if response.status().is_success() {
let response_text = response.text().await?;
debug!("msg_sec_check response body: {}", response_text);
let result: MsgSecCheckResult = serde_json::from_str(&response_text)?;
if result.is_success() {
Ok(result)
} else {
Err(Error::InternalServer(format!(
"微信内容安全检测API错误: {} - {}",
result.errcode, result.errmsg
)))
}
} else {
Err(Error::InternalServer(response.text().await?))
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder keeps the supplied fields and fills in the default version.
    #[test]
    fn test_args_builder() {
        let builder = Args::builder()
            .content("测试内容")
            .scene(Scene::Comment)
            .openid("test_openid");
        let built = builder.build().unwrap();

        assert_eq!(built.content, "测试内容");
        assert_eq!(built.version, 2);
        assert_eq!(built.scene, Scene::Comment);
        assert_eq!(built.openid, "test_openid");
    }

    /// `build` rejects missing content, oversized content and a signature
    /// outside the profile scene.
    #[test]
    fn test_args_builder_validation() {
        // `content` never set.
        let missing_content = Args::builder()
            .scene(Scene::Comment)
            .openid("test_openid")
            .build();
        assert!(missing_content.is_err());

        // One character past the limit.
        let oversized = Args::builder()
            .content("a".repeat(2501))
            .scene(Scene::Comment)
            .openid("openid")
            .build();
        assert!(oversized.is_err());

        // Signature combined with a non-profile scene.
        let misplaced_signature = Args::builder()
            .content("内容")
            .scene(Scene::Comment)
            .openid("openid")
            .signature("签名")
            .build();
        assert!(misplaced_signature.is_err());
    }

    /// Value lookup, description and integer cast agree for the profile scene.
    #[test]
    fn test_scene_enum() {
        let profile = Scene::Profile;
        assert_eq!(Scene::from_value(1), Some(profile));
        assert_eq!(profile.description(), "资料");
        assert_eq!(profile as u32, 1);
    }

    /// A passing response deserializes and is classified as a pass.
    #[test]
    fn test_msg_sec_check_result() {
        let payload = r#"
{
"errcode": 0,
"errmsg": "ok",
"detail": [
{
"strategy": "content_model",
"errcode": 0,
"suggest": "pass",
"label": 100,
"prob": 90.5
}
],
"result": {
"suggest": "pass",
"label": 100
},
"trace_id": "test_trace_id"
}"#;
        let parsed: MsgSecCheckResult = serde_json::from_str(payload).unwrap();

        assert!(parsed.is_success());
        assert!(parsed.is_pass());
        assert!(!parsed.is_risky());
        assert!(!parsed.needs_review());
        assert_eq!(parsed.get_valid_details().len(), 1);
        assert_eq!(parsed.trace_id.as_deref(), Some("test_trace_id"));
    }

    /// A risky response deserializes, is classified as risky and exposes the keyword.
    #[test]
    fn test_msg_sec_check_result_with_risk() {
        let payload = r#"
{
"errcode": 0,
"errmsg": "ok",
"detail": [
{
"strategy": "content_model",
"errcode": 0,
"suggest": "risky",
"label": 20001,
"keyword": "敏感词",
"prob": 95.0
}
],
"result": {
"suggest": "risky",
"label": 20001
}
}"#;
        let parsed: MsgSecCheckResult = serde_json::from_str(payload).unwrap();

        assert!(parsed.is_success());
        assert!(!parsed.is_pass());
        assert!(parsed.is_risky());
        assert!(!parsed.needs_review());

        let valid = parsed.get_valid_details();
        assert_eq!(valid[0].keyword.as_deref(), Some("敏感词"));
    }
}