use crate::inner::{create_client, decode, uncompress};
use crate::{VkApiError, VkApiResult};
use bytes::Buf;
use cfg_if::cfg_if;
use reqwest::header::{ACCEPT, ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_TYPE};
use reqwest::Client;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Deserializer, Serialize};
use std::error::Error;
use std::fmt::{Display, Formatter};
/// Long poll client; every request is sent through the wrapped `reqwest` [`Client`].
///
/// Cloning is cheap as long as cloning the inner [`Client`] is (it is cloned into every
/// stream created by `subscribe`).
#[derive(Clone, Debug)]
pub struct VkLongPoll {
    client: Client,
}
impl VkLongPoll {
    /// Polls `request.server` in a loop and yields every received update as a stream item.
    ///
    /// After each successful poll the request's `ts` is replaced by the one returned from the
    /// server, so the next poll continues where the previous one stopped. A `failed: 1` answer
    /// that carries a fresh `ts` is handled transparently by continuing from that `ts`. Any
    /// other error is yielded once, after which the stream ends.
    #[cfg(feature = "longpoll_stream")]
    pub fn subscribe<T: Serialize + Clone + Send, I: DeserializeOwned>(
        &self,
        mut request: LongPollRequest<T>,
    ) -> impl futures_util::Stream<Item = VkApiResult<I>> {
        let client = self.client.clone();
        async_stream::stream! {
            loop {
                match Self::subscribe_once_with_client(&client, request.clone()).await {
                    Ok(LongPollSuccess { ts, updates }) => {
                        // `ts` is already owned here; move it instead of cloning.
                        request.ts = ts;
                        for update in updates {
                            yield Ok(update);
                        }
                    }
                    // Per the VK long poll protocol only `failed: 1` ("event history is
                    // outdated") means "continue with this new ts". Other codes (expired key,
                    // lost information) cannot be fixed by re-sending the same request, so
                    // retrying them would loop forever; surface them to the caller instead.
                    Err(VkApiError::LongPoll(LongPollError {
                        failed: 1,
                        ts: Some(ts),
                        ..
                    })) => {
                        request.ts = ts;
                    }
                    Err(e) => {
                        yield Err(e);
                        break;
                    }
                }
            }
        }
    }

    /// Performs a single `act=a_check` long poll request and returns the server's answer.
    ///
    /// # Errors
    ///
    /// * [`VkApiError::RequestSerialize`] if the query parameters cannot be URL-encoded;
    /// * [`VkApiError::Request`] if sending the request or reading its body fails;
    /// * [`VkApiError::LongPoll`] if the server answered with a `failed` object;
    /// * whatever `uncompress` / `decode` return when the body cannot be decoded.
    pub async fn subscribe_once<T: Serialize + Send, I: DeserializeOwned>(
        &self,
        request: LongPollRequest<T>,
    ) -> VkApiResult<LongPollSuccess<I>> {
        Self::subscribe_once_with_client(&self.client, request).await
    }

    /// Shared implementation of [`Self::subscribe_once`] taking the client explicitly, so the
    /// stream built by `subscribe` can own its own clone of the client.
    async fn subscribe_once_with_client<T: Serialize + Send, I: DeserializeOwned>(
        client: &Client,
        request: LongPollRequest<T>,
    ) -> VkApiResult<LongPollSuccess<I>> {
        let LongPollInnerRequest(LongPollServer(server), params) =
            LongPollInnerRequest::from(request);
        let params = serde_urlencoded::to_string(params).map_err(VkApiError::RequestSerialize)?;
        let url = Self::poll_url(&server, &params);
        // Advertise gzip only when the `compression_gzip` feature is enabled (presumably the
        // feature that lets `uncompress` decode it — TODO confirm in `inner`).
        cfg_if! {
            if #[cfg(feature = "compression_gzip")] {
                let encoding = "gzip";
            } else {
                let encoding = "identity";
            }
        }
        cfg_if! {
            if #[cfg(feature = "encode_json")] {
                let serialisation = "application/json";
            } else {
                let serialisation = "text/*";
            }
        }
        let response = client
            .get(url)
            .header(ACCEPT_ENCODING, encoding)
            .header(ACCEPT, serialisation)
            .send()
            .await
            .map_err(VkApiError::Request)?;
        // Copy the headers we need before `bytes()` consumes the response.
        let headers = response.headers();
        let content_type = headers.get(CONTENT_TYPE).cloned();
        let content_encoding = headers.get(CONTENT_ENCODING).cloned();
        let body = response.bytes().await.map_err(VkApiError::Request)?;
        let resp = decode::<LongPollResponse<I>, _>(
            content_type,
            uncompress(content_encoding, body.reader())?,
        )?;
        match resp {
            LongPollResponse::Success(r) => Ok(r),
            LongPollResponse::Error(e) => Err(VkApiError::LongPoll(e)),
        }
    }

    /// Builds the `act=a_check` URL for `server` with the already URL-encoded `params`.
    ///
    /// `server` may be given with or without an `http://` / `https://` scheme; without one,
    /// `https://` is prepended. Only a real scheme prefix counts, so a host name that merely
    /// begins with the letters "http" is still treated as scheme-less.
    fn poll_url(server: &str, params: &str) -> String {
        if server.starts_with("https://") || server.starts_with("http://") {
            format!("{server}?act=a_check&{params}")
        } else {
            format!("https://{server}?act=a_check&{params}")
        }
    }
}
impl From<Client> for VkLongPoll {
fn from(client: Client) -> Self {
Self { client }
}
}
/// Builds a long poll client on top of the client returned by the crate's `create_client`.
impl Default for VkLongPoll {
    fn default() -> Self {
        let client = create_client();
        client.into()
    }
}
/// Raw body of a long poll answer: either a batch of updates or a `failed` object.
///
/// With `untagged`, serde tries the variants in declaration order: `Success` (which needs
/// both `ts` and `updates`) first, then `Error` (which needs `failed`). The order is
/// therefore significant.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
enum LongPollResponse<R> {
    /// A normal answer carrying updates and the next `ts`.
    Success(LongPollSuccess<R>),
    /// A `failed` answer reported by the server.
    Error(LongPollError),
}
/// A successful long poll answer: the received updates and the `ts` to poll from next.
///
/// The fields are private; use [`LongPollSuccess::ts`], [`LongPollSuccess::updates`] or
/// [`LongPollSuccess::into_parts`] to read them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LongPollSuccess<R> {
    /// `ts` to send with the next request; accepted as either a number or a string.
    #[serde(deserialize_with = "deserialize_usize_or_string")]
    ts: String,
    /// Updates received in this poll, in the order they were deserialized.
    updates: Vec<R>,
}

impl<R> LongPollSuccess<R> {
    /// The `ts` to use for the next request.
    pub fn ts(&self) -> &str {
        &self.ts
    }

    /// The updates received in this poll.
    pub fn updates(&self) -> &[R] {
        &self.updates
    }

    /// Consumes the answer and returns `(ts, updates)` without copying.
    pub fn into_parts(self) -> (String, Vec<R>) {
        (self.ts, self.updates)
    }
}
/// A `failed` answer from a long poll server.
///
/// The fields are private; the accessors below expose them so callers can react to a
/// specific `failed` code (for example by requesting a new key).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LongPollError {
    /// Code from the `failed` field.
    failed: usize,
    /// New `ts` if the server sent one; accepted as either a number or a string.
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_usize_or_string_option")]
    ts: Option<String>,
    /// `min_version` field, if present in the answer.
    #[serde(default)]
    min_version: Option<usize>,
    /// `max_version` field, if present in the answer.
    #[serde(default)]
    max_version: Option<usize>,
}

impl LongPollError {
    /// The code reported in the `failed` field.
    pub fn failed(&self) -> usize {
        self.failed
    }

    /// The new `ts` sent with the error, if any.
    pub fn ts(&self) -> Option<&str> {
        self.ts.as_deref()
    }

    /// The `min_version` field, if present.
    pub fn min_version(&self) -> Option<usize> {
        self.min_version
    }

    /// The `max_version` field, if present.
    pub fn max_version(&self) -> Option<usize> {
        self.max_version
    }
}
impl Display for LongPollError {
    /// Renders the error as `long poll error occurred, code: <failed>`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("long poll error occurred, code: {}", self.failed))
    }
}

/// No underlying cause is stored, so the default `source()` (returning `None`) is kept.
impl Error for LongPollError {}
/// Everything needed to send one long poll request.
///
/// `server` selects where the request goes; all other fields become query parameters of
/// `<server>?act=a_check&…`. Presumably `server`, `key` and `ts` come from the API method
/// that hands out long poll servers — TODO confirm against callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LongPollRequest<T> {
    /// Long poll server address; the request is sent to `<server>?act=a_check&…`.
    pub server: String,
    /// Sent as the `key` query parameter.
    pub key: String,
    /// Sent as the `ts` query parameter; accepted as a number or a string when deserialized.
    #[serde(deserialize_with = "deserialize_usize_or_string")]
    pub ts: String,
    /// Sent as the `wait` query parameter (presumably the wait time in seconds — confirm).
    pub wait: usize,
    /// Extra query parameters, flattened next to `key`, `ts` and `wait`.
    #[serde(flatten)]
    pub additional_params: T,
}
/// Address of the long poll server, split off from the query parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LongPollServer(String);

/// Query-string part of a long poll request; URL-encoded and appended after `act=a_check`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LongPollQueryParams<T> {
    key: String,
    #[serde(deserialize_with = "deserialize_usize_or_string")]
    ts: String,
    wait: usize,
    /// Caller-supplied parameters serialized alongside `key`, `ts` and `wait`.
    #[serde(flatten)]
    additional_params: T,
}

/// A [`LongPollRequest`] split into the server address and its query parameters.
struct LongPollInnerRequest<T>(LongPollServer, LongPollQueryParams<T>);
impl<T> From<LongPollRequest<T>> for LongPollInnerRequest<T> {
fn from(
LongPollRequest {
server,
key,
ts,
wait,
additional_params,
}: LongPollRequest<T>,
) -> Self {
Self(
LongPollServer(server),
LongPollQueryParams {
key,
ts,
wait,
additional_params,
},
)
}
}
/// Serde visitor that accepts a non-negative integer or a string and yields it as a `String`.
struct DeserializeUsizeOrString;
impl<'de> serde::de::Visitor<'de> for DeserializeUsizeOrString {
    type Value = String;
    fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
        formatter.write_str("an integer or a string")
    }
    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        Ok(v.to_string())
    }
    /// Integers that a format delivers as signed values (serde forwards `visit_i8`..`visit_i32`
    /// here) were previously rejected even when non-negative. Accept those; a negative value
    /// cannot be a `usize`, so it is still rejected.
    fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        u64::try_from(v)
            .map(|unsigned| unsigned.to_string())
            .map_err(|_| E::invalid_value(serde::de::Unexpected::Signed(v), &self))
    }
    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        Ok(v.to_owned())
    }
    fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        Ok(v)
    }
}
struct DeserializeUsizeOrStringOption;
impl<'de> serde::de::Visitor<'de> for DeserializeUsizeOrStringOption {
type Value = Option<String>;
fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
formatter.write_str("an integer or a string or a null")
}
fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(Some(v.to_string()))
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(Some(v.to_owned()))
}
fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(Some(v))
}
fn visit_none<E>(self) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(None)
}
}
/// `deserialize_with` helper for `String` fields that may arrive as an integer or a string.
///
/// Relies on `deserialize_any`, so it needs a self-describing input format.
fn deserialize_usize_or_string<'de, D>(deserializer: D) -> Result<String, D::Error>
where
    D: Deserializer<'de>,
{
    let visitor = DeserializeUsizeOrString;
    deserializer.deserialize_any(visitor)
}
/// `deserialize_with` helper for `Option<String>` fields that may arrive as an integer or a
/// string.
///
/// Pair it with `#[serde(default)]` so that a missing field yields `None`. Like the
/// non-optional helper it uses `deserialize_any` and needs a self-describing input format.
fn deserialize_usize_or_string_option<'de, D>(
    deserializer: D,
) -> Result<Option<String>, D::Error>
where
    D: Deserializer<'de>,
{
    let visitor = DeserializeUsizeOrStringOption;
    deserializer.deserialize_any(visitor)
}
#[cfg(test)]
mod tests {
    use super::{deserialize_usize_or_string, deserialize_usize_or_string_option};
    use serde::Deserialize;

    /// Wrapper for the required-`ts` helper.
    #[derive(Deserialize)]
    struct Ts {
        #[serde(deserialize_with = "deserialize_usize_or_string")]
        ts: String,
    }

    /// Wrapper for the optional-`ts` helper; `default` makes a missing field `None`.
    #[derive(Deserialize)]
    struct TsOpt {
        #[serde(default, deserialize_with = "deserialize_usize_or_string_option")]
        ts: Option<String>,
    }

    /// Parses `json`, failing the test with a clear message if it does not deserialize.
    fn parse<T: for<'de> Deserialize<'de>>(json: &str) -> T {
        serde_json::from_str(json).expect("test JSON must deserialize")
    }

    #[test]
    fn test_deserialize_ts_string() {
        assert_eq!(parse::<Ts>(r#"{"ts": "123"}"#).ts, "123");
    }

    #[test]
    fn test_deserialize_ts_usize() {
        assert_eq!(parse::<Ts>(r#"{"ts": 123}"#).ts, "123");
    }

    #[test]
    fn test_deserialize_ts_opt_string() {
        assert_eq!(parse::<TsOpt>(r#"{"ts": "123"}"#).ts.as_deref(), Some("123"));
    }

    #[test]
    fn test_deserialize_ts_opt_usize() {
        assert_eq!(parse::<TsOpt>(r#"{"ts": 123}"#).ts.as_deref(), Some("123"));
    }

    #[test]
    fn test_deserialize_ts_opt_none() {
        assert_eq!(parse::<TsOpt>("{}").ts, None);
    }
}