rustymilky 0.1.0

Milky 协议的 Rust SDK
Documentation
use std::time::Duration;

use anyhow::{Context, anyhow};

use crate::{
  Result,
  protocol::{ApiEndpoint, ApiGeneralResponse},
  retry::{RetryOptions, RetryState, is_retryable_request_error},
  transport::{MilkyTransport, MilkyTransportKind, spawn_sse_transport, spawn_websocket_transport},
  url::{build_api_url, build_event_url, websocket_request},
};

/// Client configuration for [`MilkyClient`].
#[derive(Debug, Clone, PartialEq)]
pub struct MilkyClientConfig {
  /// Reconnect policy applied after the event stream disconnects.
  pub event_reconnect: RetryOptions,
  /// Retry policy for ordinary API requests that fail with a recoverable network error.
  pub request_retry: RetryOptions,
}

impl Default for MilkyClientConfig {
  /// Builds the default configuration.
  ///
  /// The event stream reconnects with exponential backoff: it starts at 300 ms, multiplies by 2,
  /// caps at 5 s and sets `max_retries: None`. Plain API requests are not retried.
  fn default() -> Self {
    let event_reconnect = RetryOptions::Exponential {
      initial_delay: Duration::from_millis(300),
      factor: 2.0,
      max_delay: Some(Duration::from_secs(5)),
      max_retries: None,
    };
    let request_retry = RetryOptions::Disabled;

    Self {
      event_reconnect,
      request_retry,
    }
  }
}

/// Asynchronous Milky client.
///
/// Sends API requests over HTTP and can open an event stream over either SSE or WebSocket.
#[derive(Debug, Clone)]
pub struct MilkyClient {
  /// Root URL of the Milky API; API and event URLs are built from it.
  base_url: reqwest::Url,
  /// HTTP client used for API calls and the SSE stream, built with the default headers.
  client: reqwest::Client,
  /// Pre-built `Authorization` value, handed to `websocket_request` for WebSocket streams.
  authorization: Option<reqwest::header::HeaderValue>,
  /// Retry and reconnect policies for this client.
  config: MilkyClientConfig,
}

impl MilkyClient {
  /// 使用默认配置创建客户端。
  ///
  /// `token` 会被作为 Bearer Token 写入 `Authorization` 请求头。
  pub fn new(base_url: &str, token: Option<&str>) -> Result<Self> {
    Self::with_config(base_url, token, MilkyClientConfig::default())
  }

  /// 使用显式配置创建客户端。
  ///
  /// `base_url` 应指向 Milky API 的根地址,例如 `http://127.0.0.1:3100`。
  pub fn with_config(
    base_url: &str,
    token: Option<&str>,
    config: MilkyClientConfig,
  ) -> Result<Self> {
    let base_url = reqwest::Url::parse(base_url).context("Failed to parse base URL")?;

    let mut client = reqwest::Client::builder();
    let authorization = token
      .map(|token| reqwest::header::HeaderValue::from_str(&format!("Bearer {token}")))
      .transpose()
      .context("Failed to parse authorization header")?;

    let mut headers = reqwest::header::HeaderMap::new();
    headers.insert(
      reqwest::header::CONTENT_TYPE,
      reqwest::header::HeaderValue::from_str("application/json")
        .context("Failed to parse content type header")?,
    );

    if let Some(authorization) = authorization.as_ref() {
      headers.insert(reqwest::header::AUTHORIZATION, authorization.clone());
    }

    client = client.default_headers(headers);
    let client = client.build().context("Failed to build client")?;

    Ok(Self {
      base_url,
      client,
      authorization,
      config,
    })
  }

  /// 发送一次 API 请求,并使用客户端默认的请求重试策略。
  ///
  /// 该方法会将返回的通用响应包解包为具体输出类型;当服务端返回非零
  /// `retcode` 时,会转换为错误返回。
  pub async fn request<T>(&self, input: T) -> Result<T::Output>
  where
    T: ApiEndpoint,
  {
    self.send_request_with_retry(input, None).await
  }

  /// 发送一次 API 请求,并允许为当前调用覆盖默认重试策略。
  ///
  /// 当 `retry` 为 `None` 时,会回退到客户端设定的默认重试策略。
  pub async fn send_request_with_retry<T>(
    &self,
    input: T,
    retry: Option<RetryOptions>,
  ) -> Result<T::Output>
  where
    T: ApiEndpoint,
  {
    let url = build_api_url::<T>(&self.base_url)?;
    let response = self
      .send_request_response_with_retry(
        url,
        &input,
        retry.unwrap_or_else(|| self.config.request_retry.clone()),
      )
      .await?;
    let output = response.json::<ApiGeneralResponse<T::Output>>().await?;

    if output.retcode != 0 {
      Err(anyhow!(
        "milky error: code={}, message={:?}",
        output.retcode,
        output.message
      ))
    } else {
      output.data.ok_or(anyhow!("milky error: data not found"))
    }
  }

  /// 建立一个基于 SSE 的事件流连接。
  pub async fn event_sse(&self) -> Result<MilkyTransport> {
    let url = build_event_url(&self.base_url)?;
    Ok(spawn_sse_transport(
      self.client.clone(),
      url,
      self.config.event_reconnect.clone(),
    ))
  }

  /// 建立一个基于 WebSocket 的事件流连接。
  pub async fn event_ws(&self) -> Result<MilkyTransport> {
    let event_url = build_event_url(&self.base_url)?;
    let request = websocket_request(&event_url, self.authorization.as_ref())?;

    Ok(spawn_websocket_transport(
      request,
      self.config.event_reconnect.clone(),
    ))
  }

  /// 按传输类型建立事件流连接。
  pub async fn event(&self, kind: MilkyTransportKind) -> Result<MilkyTransport> {
    match kind {
      MilkyTransportKind::WebSocket => self.event_ws().await,
      MilkyTransportKind::Sse => self.event_sse().await,
    }
  }

  async fn send_request_response_with_retry<T>(
    &self,
    url: reqwest::Url,
    input: &T,
    retry: RetryOptions,
  ) -> Result<reqwest::Response>
  where
    T: ApiEndpoint,
  {
    let mut retry_state = RetryState::new(retry);

    loop {
      match self.client.post(url.clone()).json(input).send().await {
        Ok(response) => return Ok(response),
        Err(error) if is_retryable_request_error(&error) => {
          let Some(decision) = retry_state.next() else {
            return Err(error.into());
          };

          tokio::time::sleep(decision.delay).await;
        }
        Err(error) => return Err(error.into()),
      }
    }
  }
}

#[cfg(test)]
mod tests {
  use super::*;

  /// Pins the values produced by `MilkyClientConfig::default()`.
  ///
  /// Judging by the test name, the reconnect policy is meant to mirror the transport's own
  /// defaults (TODO confirm). API request retries must stay disabled.
  #[test]
  fn default_client_config_matches_transport_defaults() {
    let expected_reconnect = RetryOptions::Exponential {
      initial_delay: Duration::from_millis(300),
      factor: 2.0,
      max_delay: Some(Duration::from_secs(5)),
      max_retries: None,
    };

    let MilkyClientConfig {
      event_reconnect,
      request_retry,
    } = MilkyClientConfig::default();

    assert_eq!(event_reconnect, expected_reconnect);
    assert_eq!(request_retry, RetryOptions::Disabled);
  }
}