use clap::Args;
use color_eyre::Result;
use color_eyre::eyre::{OptionExt, eyre};
use convertor::config::ConvertorConfig;
use convertor::config::client_config::ProxyClient;
use convertor::config::provider_config::Provider;
use convertor::core::profile::Profile;
use convertor::core::profile::clash_profile::ClashProfile;
use convertor::core::profile::extract_policies_for_rule_provider;
use convertor::core::profile::policy::Policy;
use convertor::core::profile::surge_profile::SurgeProfile;
use convertor::error::UrlBuilderError;
use convertor::provider_api::ProviderApi;
use convertor::url::convertor_url::ConvertorUrlType;
use convertor::url::query::ConvertorQuery;
use convertor::url::url_builder::{HostPort, UrlBuilder};
use convertor::url::url_result::UrlResult;
use headers::UserAgent;
use std::collections::HashMap;
use url::Url;
// Command-line arguments for building convertor URLs for one proxy client / provider pair.
// Plain `//` comments are used on purpose: clap's derive turns `///` doc comments into
// `--help` text, and these notes are for maintainers only.
#[derive(Default, Debug, Clone, Hash, Args)]
pub struct ProviderCmd {
    // Target proxy client (first positional argument).
    #[arg(value_enum)]
    pub client: ProxyClient,
    // Subscription provider whose API is used; defaults to BosLife.
    #[arg(value_enum, default_value_t = Provider::BosLife)]
    pub provider: Provider,
    // Optional positional URL. `ProviderCli::detect_url` treats it as a convertor profile URL
    // when it starts with the server address, otherwise as a raw subscription URL. When
    // absent, the subscription URL is fetched from the provider API.
    #[arg()]
    pub url: Option<Url>,
    // Convertor server address (`-s/--server`); falls back to `config.server` when unset.
    #[arg(short, long)]
    pub server: Option<Url>,
    // Interval override (`-i/--interval`); falls back to the client config's `interval()`.
    #[arg(short, long)]
    pub interval: Option<u64>,
    // Strict-mode override (`-S/--strict`); falls back to the client config's `strict()`.
    #[arg(short = 'S', long)]
    pub strict: Option<bool>,
    // `-r/--reset`: when no `url` is given, ask the provider API for a reset subscription URL
    // (`reset_sub_url`) instead of the current one (`get_sub_url`).
    #[arg(short, long, default_value_t = false)]
    pub reset: bool,
    // `-u/--update`: also update the local client configuration via `super::update`
    // (only compiled with the `update` feature).
    #[cfg(feature = "update")]
    #[arg(short, long, default_value_t = false)]
    pub update: bool,
}
/// Executes a [`ProviderCmd`]: resolves the subscription and builds the convertor URLs for it.
pub struct ProviderCli {
    /// Loaded convertor configuration (server address, secret and per-client settings).
    pub config: ConvertorConfig,
    /// API implementation for each subscription provider, keyed by provider.
    pub api_map: HashMap<Provider, ProviderApi>,
}
/// Parsed profile kept by `ProviderCli::execute` after the URLs are built, so that the
/// optional `update` step can reuse it.
///
/// The Surge update path only needs the policy list, so `Surge` carries no data; the parsed
/// Clash profile is retained only when the `update` feature is enabled.
// `Clash(ClashProfile)` is much larger than the empty `Surge` variant. In the visible code the
// value is only a short-lived local inside `execute`, so boxing the large variant is not worth it.
#[allow(clippy::large_enum_variant)]
enum ClientProfile {
    /// A Surge profile was processed; nothing needs to be kept.
    Surge,
    /// The parsed Clash profile, later handed to `update_clash_config`.
    #[cfg(feature = "update")]
    Clash(ClashProfile),
    /// A Clash profile was processed; without the `update` feature nothing needs to be kept.
    #[cfg(not(feature = "update"))]
    Clash,
}
impl ProviderCli {
pub fn new(config: ConvertorConfig, api_map: HashMap<Provider, ProviderApi>) -> Self {
Self { config, api_map }
}
pub async fn execute(&mut self, cmd: ProviderCmd) -> Result<(UrlBuilder, UrlResult)> {
let client = cmd.client;
let provider = cmd.provider;
let url_builder = self.create_url_builder(&cmd).await?;
let api = self
.api_map
.get_mut(&provider)
.ok_or(eyre!("无法取得订阅供应商的 api 实现: {}", &provider))?;
api.set_sub_url(url_builder.sub_url.clone());
let raw_profile_content = api
.get_raw_profile(client, UserAgent::from_static("Surge Mac/8310"))
.await?;
let sub_host = url_builder
.sub_url
.host_port()
.ok_or_eyre("无法从 sub_url 中提取 host port")?;
let (_client_profile, policies) = match client {
ProxyClient::Surge => {
let mut raw_profile = SurgeProfile::parse(raw_profile_content)?;
raw_profile.convert(&url_builder)?;
let mut policies: Vec<Policy> = raw_profile.policy_of_rules.keys().cloned().collect();
policies.sort();
(ClientProfile::Surge, policies)
}
ProxyClient::Clash => {
let raw_profile = ClashProfile::parse(raw_profile_content)?;
let policies = extract_policies_for_rule_provider(&raw_profile.rules, sub_host);
#[cfg(feature = "update")]
{
(ClientProfile::Clash(raw_profile), policies)
}
#[cfg(not(feature = "update"))]
{
(ClientProfile::Clash, policies)
}
}
};
let raw_url = url_builder.build_raw_url();
let profile_url = url_builder.build_profile_url()?;
let raw_profile_url = url_builder.build_raw_profile_url()?;
let sub_logs_url = url_builder.build_sub_logs_url()?;
let rule_provider_urls = policies
.iter()
.map(|policy| url_builder.build_rule_provider_url(policy))
.collect::<Result<Vec<_>, UrlBuilderError>>()?;
let result = UrlResult {
raw_url,
raw_profile_url,
profile_url,
sub_logs_url,
rule_providers_url: rule_provider_urls,
};
#[cfg(feature = "update")]
if cmd.update {
match _client_profile {
ClientProfile::Surge => {
super::update::update_surge_config(&self.config, &url_builder, &policies).await?;
}
ClientProfile::Clash(profile) => {
super::update::update_clash_config(&self.config, &url_builder, profile).await?;
}
}
}
Ok((url_builder, result))
}
pub fn post_execute(&self, _url_builder: UrlBuilder, result: UrlResult) {
println!("{result}");
}
async fn create_url_builder(&self, cmd: &ProviderCmd) -> Result<UrlBuilder> {
let ProviderCmd {
client,
provider,
url,
server,
interval,
strict,
reset,
#[cfg(feature = "update")]
update: _,
} = cmd;
let client_config = self
.config
.clients
.get(client)
.ok_or(eyre!("无法取得代理客户端的配置: {}", client))?;
let server = server.clone().unwrap_or_else(|| self.config.server.clone());
let mut enc_secret = None;
let mut enc_sub_url = None;
let mut interval = interval
.as_ref()
.map(|i| *i)
.unwrap_or_else(|| client_config.interval());
let mut strict = strict.as_ref().map(|s| *s).unwrap_or_else(|| client_config.strict());
let url_type = self.detect_url(cmd);
let sub_url = match (url_type, reset) {
(None, false) => {
self.api_map
.get(provider)
.ok_or(eyre!("无法取得订阅供应商的 api 实现: {}", &provider))?
.get_sub_url()
.await?
}
(None, true) => {
self.api_map
.get(provider)
.ok_or(eyre!("无法取得订阅供应商的 api 实现: {}", &provider))?
.reset_sub_url()
.await?
}
(Some(ConvertorUrlType::Raw), _) => url.clone().unwrap(),
(Some(ConvertorUrlType::Profile), _) => {
let profile_query = url.as_ref().and_then(Url::query).ok_or(eyre!("订阅链接缺少查询参数"))?;
let query =
ConvertorQuery::parse_from_query_string(profile_query, &self.config.secret, server.clone())?;
enc_secret = query.enc_secret.clone();
enc_sub_url = Some(query.enc_sub_url.clone());
interval = query.interval;
strict = query.strict.unwrap_or(strict);
query.sub_url.clone()
}
_ => unreachable!("不支持的订阅链接类型"),
};
let url_builder = UrlBuilder::new(
self.config.secret.clone(),
enc_secret,
*client,
*provider,
server,
sub_url,
enc_sub_url,
interval,
strict,
)?;
Ok(url_builder)
}
fn detect_url(&self, cmd: &ProviderCmd) -> Option<ConvertorUrlType> {
let ProviderCmd { url, server, .. } = cmd;
let server = server
.as_ref()
.map(|s| s.to_string())
.unwrap_or_else(|| self.config.server.to_string());
url.as_ref().map(|url| {
if url.as_str().starts_with(&server) {
ConvertorUrlType::Profile
} else {
ConvertorUrlType::Raw
}
})
}
}
impl ProviderCmd {
    /// Builds a deterministic name that identifies this command's arguments, e.g. for snapshots.
    ///
    /// Each argument contributes one `-`-separated segment, with a `no_*` placeholder when it is
    /// absent. `strict` is encoded as `strict` (true), `not_strict` (false) or `no_strict` (unset),
    /// so `-S true` and `-S false` no longer collide. A url without a host yields `no_host`
    /// instead of panicking.
    pub fn snapshot_name(&self) -> String {
        let client = self.client.to_string();
        let provider = self.provider.to_string();
        let url = self.url.as_ref().map_or_else(
            || "no_url".to_string(),
            |url| url.host_port().unwrap_or_else(|| "no_host".to_string()),
        );
        let server = self
            .server
            .as_ref()
            .map_or_else(|| "no_server".to_string(), |server| server.to_string());
        let interval = self
            .interval
            .map_or_else(|| "no_interval".to_string(), |interval| interval.to_string());
        let strict = match self.strict {
            Some(true) => "strict",
            Some(false) => "not_strict",
            None => "no_strict",
        };
        let reset = if self.reset { "reset" } else { "no_reset" };
        #[cfg(feature = "update")]
        let update = if self.update { "update" } else { "no_update" };
        #[cfg(not(feature = "update"))]
        let update = "no_update";
        format!("{client}-{provider}-{url}-{server}-{interval}-{strict}-{reset}-{update}")
    }
}