stock-rust 0.1.0

Rust version of stock-api
Documentation
use std::sync::{Arc, Mutex};

use anyhow::{anyhow, Result};
use async_trait::async_trait;
use reqwest::header::SET_COOKIE;
use serde_json::Value;

use crate::stocks::base::{
    as_f64, as_string, dedup_codes, percent,
};
use crate::stocks::transforms::base::CommonCodeTransform;
use crate::stocks::transforms::xueqiu::{XueqiuCommonCodeTransform, XueqiuSearchCodeTransform};
use crate::types::{Stock, StockApi};

/// [`StockApi`] implementation backed by the Xueqiu (雪球) quote endpoints.
///
/// The authentication token lives behind an `Arc`, so every clone of a
/// `XueqiuApi` shares (and refreshes) the same token cache.
#[derive(Clone)]
pub struct XueqiuApi {
    /// Lazily fetched `xq_a_token=<value>` cookie pair; `None` until the
    /// first request that needs it.
    token: Arc<Mutex<Option<String>>>,
    /// HTTP client used for the token request and every quote/search call.
    client: reqwest::Client,
}

impl XueqiuApi {
    pub fn new(client: reqwest::Client) -> Self {
        Self {
            client,
            token: Arc::new(Mutex::new(None)),
        }
    }

    async fn get_token(&self) -> Result<String> {
        if let Some(token) = self.token.lock().map_err(|_| anyhow!("token lock error"))?.clone() {
            return Ok(token);
        }

        let res = self.client.get("https://xueqiu.com/").send().await?;
        let mut token = String::new();

        for value in &res.headers().get_all(SET_COOKIE) {
            let raw = value.to_str().unwrap_or_default();
            if raw.contains("xq_a_token") {
                token = raw.split(';').next().unwrap_or_default().to_string();
                break;
            }
        }

        if token.is_empty() {
            return Err(anyhow!("雪球 token 获取失败"));
        }

        *self.token.lock().map_err(|_| anyhow!("token lock error"))? = Some(token.clone());
        Ok(token)
    }

    fn stock_from_quote(code: &str, quote: &Value) -> Stock {
        let name = as_string(quote.get("name"));
        let now = as_f64(quote.get("current"));
        let low = as_f64(quote.get("low"));
        let high = as_f64(quote.get("high"));
        let yesterday = as_f64(quote.get("last_close"));

        Stock {
            code: code.to_uppercase(),
            name,
            now,
            low,
            high,
            yesterday,
            percent: percent(now, yesterday),
        }
    }
}

#[async_trait]
impl StockApi for XueqiuApi {
    /// Fetches the quote for a single `code`.
    ///
    /// Returns `Stock::default_with_code(code)` when the response carries no
    /// non-null `data.quote` object.
    ///
    /// # Errors
    ///
    /// Fails if `code` cannot be transformed into a Xueqiu symbol, the token
    /// cannot be obtained, or the request / JSON decoding fails.
    async fn get_stock(&self, code: &str) -> Result<Stock> {
        let transformer = XueqiuCommonCodeTransform;
        let symbol = transformer.transform(code)?;
        let url = format!(
            "https://stock.xueqiu.com/v5/stock/quote.json?symbol={}",
            symbol
        );
        let root = self.get_json(&url).await?;

        match root.get("data").and_then(|d| d.get("quote")) {
            Some(quote) if !quote.is_null() => Ok(Self::stock_from_quote(code, quote)),
            _ => Ok(Stock::default_with_code(code)),
        }
    }

    /// Fetches quotes for several codes with one batch request.
    ///
    /// `codes` is deduplicated first; an empty list short-circuits without
    /// any network traffic. The result follows the deduplicated order, and a
    /// code whose quote is absent from the response yields
    /// `Stock::default_with_code(code)`.
    ///
    /// # Errors
    ///
    /// Fails if any code cannot be transformed, the token cannot be obtained,
    /// or the request / JSON decoding fails.
    async fn get_stocks(&self, codes: &[String]) -> Result<Vec<Stock>> {
        let codes = dedup_codes(codes);
        if codes.is_empty() {
            return Ok(Vec::new());
        }

        let transformer = XueqiuCommonCodeTransform;
        let symbols = transformer.transforms(&codes)?;
        let url = format!(
            "https://stock.xueqiu.com/v5/stock/batch/quote.json?symbol={}",
            symbols.join(",")
        );
        let root = self.get_json(&url).await?;

        // Borrow the items in place rather than cloning the whole array.
        let items: &[Value] = root
            .get("data")
            .and_then(|d| d.get("items"))
            .and_then(Value::as_array)
            .map(Vec::as_slice)
            .unwrap_or(&[]);

        let data: Vec<Stock> = codes
            .iter()
            .map(|code| {
                let symbol = transformer.transform(code).unwrap_or_default();
                let quote = items
                    .iter()
                    .find(|item| Self::batch_item_matches(item, &symbol))
                    .and_then(|item| item.get("quote"));
                match quote {
                    Some(v) if !v.is_null() => Self::stock_from_quote(code, v),
                    _ => Stock::default_with_code(code),
                }
            })
            .collect();

        Ok(data)
    }

    /// Searches Xueqiu for `key` and returns quotes for every matching code.
    ///
    /// Rows without a non-empty string `code` are skipped. Codes the search
    /// transform does not yield a result for are dropped rather than failing
    /// the whole search. Deduplication is left to `get_stocks`.
    ///
    /// # Errors
    ///
    /// Fails if the token cannot be obtained or a request / JSON decoding
    /// fails.
    async fn search_stocks(&self, key: &str) -> Result<Vec<Stock>> {
        let url = format!(
            "https://xueqiu.com/stock/search.json?code={}",
            urlencoding::encode(key)
        );
        let root = self.get_json(&url).await?;

        let search_transformer = XueqiuSearchCodeTransform;
        let codes: Vec<String> = root
            .get("stocks")
            .and_then(Value::as_array)
            .map(Vec::as_slice)
            .unwrap_or(&[])
            .iter()
            .filter_map(|row| row.get("code").and_then(Value::as_str))
            .filter(|code| !code.is_empty())
            .flat_map(|code| search_transformer.transform(code))
            .collect();

        self.get_stocks(&codes).await
    }
}

impl XueqiuApi {
    /// Sends an authenticated GET request to `url` and decodes the body as JSON.
    ///
    /// The cached `xq_a_token` cookie is reused. On a non-success HTTP status
    /// (for example, an expired cached token), the cache is cleared and the
    /// request is retried once with a freshly fetched token. The retry's
    /// body is decoded whatever its status, so callers keep falling back to
    /// default stocks when the payload has no `data`.
    ///
    /// # Errors
    ///
    /// Fails if a token cannot be obtained, a request fails, or the body is
    /// not valid JSON.
    async fn get_json(&self, url: &str) -> Result<Value> {
        let token = self.get_token().await?;
        let res = self.client.get(url).header("Cookie", token).send().await?;
        if res.status().is_success() {
            return Ok(res.json().await?);
        }
        // Release the rejected response's connection before retrying.
        drop(res);

        self.clear_cached_token();
        let token = self.get_token().await?;
        let root: Value = self
            .client
            .get(url)
            .header("Cookie", token)
            .send()
            .await?
            .json()
            .await?;
        Ok(root)
    }

    /// Drops the cached token so the next `get_token` call fetches a new one.
    ///
    /// A poisoned mutex is recovered instead of reported: resetting the cache
    /// to `None` is valid whatever state a panicking holder left behind.
    fn clear_cached_token(&self) {
        let mut cached = match self.token.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        *cached = None;
    }

    /// Returns `true` when a batch-quote `item` refers to `symbol`.
    ///
    /// The identifying field depends on `market.region`: `US` items compare
    /// `quote.code` directly, `CN` items compare `quote.symbol`, and `HK`
    /// items compare `"HK"` followed by `quote.code`. Items whose `quote` is
    /// missing or null, or whose region is anything else, never match.
    fn batch_item_matches(item: &Value, symbol: &str) -> bool {
        let quote = match item.get("quote") {
            Some(q) if !q.is_null() => q,
            _ => return false,
        };
        let region = item
            .get("market")
            .and_then(|m| m.get("region"))
            .and_then(Value::as_str);

        match region {
            Some("US") => quote.get("code").and_then(Value::as_str) == Some(symbol),
            Some("CN") => quote.get("symbol").and_then(Value::as_str) == Some(symbol),
            // Equivalent to `format!("HK{}", code) == symbol`, without allocating.
            Some("HK") => {
                let code = quote.get("code").and_then(Value::as_str);
                code.is_some() && symbol.strip_prefix("HK") == code
            }
            _ => false,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::stocks::transforms::base::CommonCodeTransform;
    use crate::stocks::transforms::xueqiu::XueqiuCommonCodeTransform;

    /// Absolute tolerance used when comparing parsed prices.
    const EPSILON: f64 = 1e-12;

    /// Asserts that `actual` lies strictly within [`EPSILON`] of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPSILON,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_to_api_code() {
        // (input code, expected Xueqiu symbol); US codes lose their prefix.
        let cases = [
            ("SH510500", "SH510500"),
            ("SZ399001", "SZ399001"),
            ("HKHSI", "HKHSI"),
            ("USDJI", "DJI"),
        ];
        let transformer = XueqiuCommonCodeTransform;
        for &(input, expected) in cases.iter() {
            assert_eq!(transformer.transform(input).unwrap(), expected);
        }
    }

    #[test]
    fn test_stock_from_quote() {
        let quote: Value = serde_json::json!({
            "name": "中证500ETF",
            "current": 7.224,
            "low": 7.085,
            "high": 7.28,
            "last_close": 7.149
        });
        let stock = XueqiuApi::stock_from_quote("SH510500", &quote);

        assert_eq!(stock.code, "SH510500");
        assert_eq!(stock.name, "中证500ETF");
        assert_close(stock.now, 7.224);
        assert_close(stock.low, 7.085);
        assert_close(stock.high, 7.28);
        assert_close(stock.yesterday, 7.149);
    }
}