use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use validator::Validate;
use super::create::{BackgroundColor, EmbeddingId, KnowledgeIcon};
use crate::client::http::HttpClient;
/// JSON body of a knowledge-base update request.
///
/// Every field is optional and omitted from the serialized JSON when `None`,
/// so only the fields the caller explicitly set are sent.
#[derive(Debug, Clone, Default, Serialize, Deserialize, Validate)]
pub struct UpdateKnowledgeBody {
    /// Embedding id to switch the knowledge base to.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub embedding_id: Option<EmbeddingId>,

    /// New display name; when present it must be non-empty.
    #[validate(length(min = 1))]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,

    /// New description text.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,

    /// New background color.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub background: Option<BackgroundColor>,

    /// New icon.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub icon: Option<KnowledgeIcon>,

    /// Callback URL (semantics defined by the remote API).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub callback_url: Option<String>,

    /// Extra headers associated with the callback (semantics defined by the remote API).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub callback_header: Option<HashMap<String, String>>,
}
impl UpdateKnowledgeBody {
    /// Returns `true` when no field is set.
    ///
    /// Because every field is skipped when `None`, an empty body would
    /// serialize to `{}`.
    fn is_empty(&self) -> bool {
        let any_field_set = self.embedding_id.is_some()
            || self.name.is_some()
            || self.description.is_some()
            || self.background.is_some()
            || self.icon.is_some()
            || self.callback_url.is_some()
            || self.callback_header.is_some();
        !any_field_set
    }
}
/// Builder and executor for a knowledge-base update (`PUT`) call.
pub struct KnowledgeUpdateRequest {
    /// API key, sent as a bearer token.
    pub key: String,
    /// Fully-formed endpoint URL, including the knowledge id.
    url: String,
    /// Fields to update; only those that are set get serialized.
    body: UpdateKnowledgeBody,
}
impl KnowledgeUpdateRequest {
    /// Creates an update request for the knowledge base identified by `id`.
    ///
    /// `key` is sent as a bearer token. The body starts with no fields set;
    /// use the `with_*` builders to choose what to change.
    pub fn new(key: String, id: impl AsRef<str>) -> Self {
        let url = format!(
            "https://open.bigmodel.cn/api/llm-application/open/knowledge/{}",
            id.as_ref()
        );
        Self {
            key,
            url,
            body: UpdateKnowledgeBody::default(),
        }
    }

    /// Sets the embedding id.
    pub fn with_embedding_id(mut self, id: EmbeddingId) -> Self {
        self.body.embedding_id = Some(id);
        self
    }

    /// Sets the display name (must be non-empty to pass validation in [`Self::send`]).
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        self.body.name = Some(name.into());
        self
    }

    /// Sets the description.
    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
        self.body.description = Some(desc.into());
        self
    }

    /// Sets the background color.
    pub fn with_background(mut self, bg: BackgroundColor) -> Self {
        self.body.background = Some(bg);
        self
    }

    /// Sets the icon.
    pub fn with_icon(mut self, icon: KnowledgeIcon) -> Self {
        self.body.icon = Some(icon);
        self
    }

    /// Sets the callback URL.
    pub fn with_callback_url(mut self, url: impl Into<String>) -> Self {
        self.body.callback_url = Some(url.into());
        self
    }

    /// Sets the callback headers, replacing any previously set map.
    pub fn with_callback_header(mut self, headers: HashMap<String, String>) -> Self {
        self.body.callback_header = Some(headers);
        self
    }

    /// Validates the body, sends the update and decodes the JSON response.
    ///
    /// # Errors
    ///
    /// - `ZaiError::ApiError` with code `1200` if no field has been set.
    /// - The error produced by `validate()` (converted via `?`), e.g. when
    ///   `name` is an empty string.
    /// - Any error returned by [`Self::put`], or from decoding the response body.
    pub async fn send(&self) -> crate::ZaiResult<KnowledgeUpdateResponse> {
        // Reject empty updates locally instead of sending `{}` to the server.
        if self.body.is_empty() {
            return Err(crate::client::error::ZaiError::ApiError {
                code: 1200,
                message: "update body is empty; set at least one field".to_string(),
            });
        }
        self.body.validate()?;
        let resp = self.put().await?;
        let parsed = resp.json::<KnowledgeUpdateResponse>().await?;
        Ok(parsed)
    }

    /// Serializes the body and issues the HTTP `PUT`, without validation.
    ///
    /// Returns the raw response on a 2xx status. On any other status, the
    /// error message is taken from `error.message` (or a top-level `message`)
    /// when the body is JSON, otherwise the raw body text is used.
    ///
    /// The returned future owns clones of the URL, key and body, so it does
    /// not borrow `self`.
    pub fn put(
        &self,
    ) -> impl std::future::Future<Output = crate::ZaiResult<reqwest::Response>> + Send {
        let url = self.url.clone();
        let key = self.key.clone();
        let body = self.body.clone();
        async move {
            let body_str = serde_json::to_string(&body)?;
            let resp = reqwest::Client::new()
                .put(url)
                .bearer_auth(key)
                .header("Content-Type", "application/json")
                .body(body_str)
                .send()
                .await?;
            let status = resp.status();
            if status.is_success() {
                return Ok(resp);
            }
            // Best effort: an unreadable error body yields an empty message
            // rather than masking the HTTP status with a decode error.
            let text = resp.text().await.unwrap_or_default();
            let message = Self::error_message_from_body(&text).unwrap_or(text);
            // The body's own error code is not forwarded; 0 is passed.
            Err(crate::client::error::ZaiError::from_api_response(
                status.as_u16(),
                0,
                message,
            ))
        }
    }

    /// Extracts a human-readable message from an error response body.
    ///
    /// Tries `error.message`, then a top-level `message`. Returns `None` if the
    /// body is not JSON or neither location holds a string. Unknown fields
    /// (such as `code`) are ignored.
    fn error_message_from_body(text: &str) -> Option<String> {
        let value: serde_json::Value = serde_json::from_str(text).ok()?;
        ["/error/message", "/message"]
            .iter()
            .find_map(|path| value.pointer(path).and_then(serde_json::Value::as_str))
            .map(str::to_owned)
    }
}
/// Exposes the request's URL, key and body through the shared [`HttpClient`] trait.
impl HttpClient for KnowledgeUpdateRequest {
    type ApiKey = String;
    type ApiUrl = String;
    type Body = UpdateKnowledgeBody;

    /// Body to serialize; only fields that are set are included.
    fn body(&self) -> &Self::Body {
        &self.body
    }

    /// Endpoint URL, already containing the knowledge id.
    fn api_url(&self) -> &Self::ApiUrl {
        &self.url
    }

    /// API key for authentication.
    fn api_key(&self) -> &Self::ApiKey {
        &self.key
    }
}
/// Response envelope returned by the knowledge update endpoint.
///
/// All fields are optional; absent fields deserialize as `None` and are
/// omitted again when re-serialized.
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct KnowledgeUpdateResponse {
    /// Status code reported in the response body (meaning defined by the API).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub code: Option<i64>,

    /// Status or error message reported in the response body.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub message: Option<String>,

    /// Server timestamp. Units are not specified here (presumably epoch milliseconds; confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timestamp: Option<u64>,
}