use crate::config::ReplicateConfig;
use anyhow::anyhow;
use bytes::Bytes;
use eventsource_stream::{EventStream, Eventsource};
use futures_lite::StreamExt;
use serde_json::Value;
use crate::models::ModelClient;
use crate::{api_key, base_url};
/// Lifecycle state of a prediction.
///
/// (De)serialized as the lowercase variant name, e.g. `"starting"` or `"succeeded"`.
/// `Hash` is derived alongside `Eq` so statuses can be used as `HashSet`/`HashMap` keys.
#[derive(serde::Serialize, serde::Deserialize, Debug, Eq, PartialEq, Hash, Clone)]
#[serde(rename_all = "lowercase")]
pub enum PredictionStatus {
    Starting,
    Processing,
    Succeeded,
    Failed,
    Canceled,
}
/// Endpoints returned by the API alongside a prediction for acting on it.
#[derive(Debug, serde::Deserialize)]
pub struct PredictionUrls {
    /// Cancellation endpoint for this prediction.
    pub cancel: String,
    /// Endpoint fetched by [`Prediction::reload`] to refresh the prediction.
    pub get: String,
    /// Server-sent-events endpoint opened by [`Prediction::get_stream`];
    /// `None` when the response carries no `stream` URL.
    pub stream: Option<String>,
}
/// A single prediction as returned by the predictions endpoints.
///
/// Only the fields below are kept; any other fields in the response body
/// (e.g. `logs`, `error`) are ignored during deserialization.
#[derive(Debug, serde::Deserialize)]
pub struct Prediction {
    /// Server-assigned prediction identifier.
    pub id: String,
    /// Model the prediction belongs to, e.g. `replicate/hello-world`.
    pub model: String,
    /// Identifier of the model version that was run.
    pub version: String,
    /// Input payload as echoed back by the API.
    pub input: Value,
    /// Lifecycle state as of the most recent fetch.
    pub status: PredictionStatus,
    /// Creation timestamp, kept as the raw string sent by the API (not parsed).
    pub created_at: String,
    /// Endpoints for reloading, cancelling and streaming this prediction.
    pub urls: PredictionUrls,
    /// Model output; `None` when the field is missing or `null` in the response.
    pub output: Option<Value>,
}
/// One page of results from the prediction list endpoint.
#[derive(Debug, serde::Deserialize)]
pub struct Predictions {
    /// Pointer to the following page; presumably a URL, `None` on the last page
    /// (NOTE(review): confirm against the API docs).
    pub next: Option<String>,
    /// Pointer to the preceding page; presumably a URL, `None` on the first page
    /// (NOTE(review): confirm against the API docs).
    pub previous: Option<String>,
    /// Predictions contained in this page.
    pub results: Vec<Prediction>,
}
impl Prediction {
    /// Refreshes this prediction in place by fetching its `urls.get` endpoint.
    ///
    /// Authenticates with the crate-level [`api_key`] helper, since a `Prediction`
    /// carries no [`ReplicateConfig`] of its own.
    ///
    /// # Errors
    /// Fails if the API key is unavailable, the request fails, the server answers with
    /// a non-success status, or the body is not a valid [`Prediction`]. On failure
    /// `self` is left untouched.
    pub async fn reload(&mut self) -> anyhow::Result<()> {
        let api_key = api_key()?;
        let response = reqwest::Client::new()
            .get(self.urls.get.as_str())
            .header("Authorization", format!("Token {api_key}"))
            .send()
            .await?;
        let status = response.status();
        let data = response.text().await?;
        // Report error responses with their status and body instead of letting them
        // fail later as an opaque JSON deserialization error.
        if !status.is_success() {
            return Err(anyhow!(
                "failed to reload prediction {}: status {}: {}",
                self.id,
                status,
                data
            ));
        }
        // Only overwrite `self` once the new state has parsed successfully.
        *self = serde_json::from_str(&data)?;
        anyhow::Ok(())
    }

    /// Returns a copy of the status as of the last fetch.
    ///
    /// This does not contact the API; call [`Prediction::reload`] first to refresh it.
    pub async fn get_status(&mut self) -> PredictionStatus {
        self.status.clone()
    }

    /// Opens the prediction's server-sent-events stream at `urls.stream`.
    ///
    /// # Errors
    /// Fails if the prediction has no stream URL, the API key is unavailable, or the
    /// request fails or is answered with a non-success status.
    pub async fn get_stream(
        &mut self,
    ) -> anyhow::Result<EventStream<impl futures_lite::stream::Stream<Item = reqwest::Result<Bytes>>>>
    {
        let stream_url = self
            .urls
            .stream
            .clone()
            .ok_or_else(|| anyhow!("prediction has no stream url available"))?;
        let api_key = api_key()?;
        let stream = reqwest::Client::new()
            .get(stream_url)
            .header("Authorization", format!("Token {api_key}"))
            .header("Accept", "text/event-stream")
            .send()
            .await?
            // Reject error responses up front rather than feeding an error page to the
            // event-stream parser.
            .error_for_status()?
            .bytes_stream()
            .eventsource();
        anyhow::Ok(stream)
    }
}
/// Client for the `/predictions` endpoints.
///
/// Every request takes its API key and base URL from the wrapped [`ReplicateConfig`].
#[derive(Debug)]
pub struct PredictionClient {
    /// Connection settings shared by all requests issued by this client.
    config: ReplicateConfig,
}
/// JSON request body that `PredictionClient::create` serializes for `POST /predictions`.
///
/// Fields are serialized in declaration order.
#[derive(serde::Serialize)]
struct PredictionInput {
    /// Model version id to run (the latest version resolved by `create`).
    version: String,
    /// Caller-supplied model input, passed through unchanged.
    input: serde_json::Value,
    /// Caller-supplied flag forwarded from `create`'s `stream` argument.
    stream: bool,
}
impl PredictionClient {
    pub fn from(config: ReplicateConfig) -> Self {
        PredictionClient { config }
    }
    pub async fn create(
        &self,
        owner: &str,
        name: &str,
        input: serde_json::Value,
        stream: bool,
    ) -> anyhow::Result<Prediction> {
        let api_key = api_key()?;
        let base_url = base_url();
        let model_client = ModelClient::from(self.config.clone());
        let version = model_client.get_latest_version(owner, name).await?.id;
        let endpoint = format!("{base_url}/predictions");
        let input = PredictionInput {
            version,
            input,
            stream,
        };
        let body = serde_json::to_string(&input)?;
        let client = reqwest::Client::new();
        let response = client
            .post(endpoint)
            .header("Authorization", format!("Token {api_key}"))
            .body(body)
            .send()
            .await?;
        let data = response.text().await?;
        let prediction: Prediction = serde_json::from_str(&data)?;
        anyhow::Ok(prediction)
    }
    pub async fn get(&self, id: String) -> anyhow::Result<Prediction> {
        let api_key = self.config.get_api_key()?;
        let base_url = self.config.get_base_url();
        let endpoint = format!("{base_url}/predictions/{id}");
        let client = reqwest::Client::new();
        let response = client
            .get(endpoint)
            .header("Authorization", format!("Token {api_key}"))
            .send()
            .await?;
        let data = response.text().await?;
        let prediction: Prediction = serde_json::from_str(&data)?;
        anyhow::Ok(prediction)
    }
    pub async fn list(&self) -> anyhow::Result<Predictions> {
        let api_key = self.config.get_api_key()?;
        let base_url = self.config.get_base_url();
        let endpoint = format!("{base_url}/predictions");
        let client = reqwest::Client::new();
        let response = client
            .get(endpoint)
            .header("Authorization", format!("Token {api_key}"))
            .send()
            .await?;
        let data = response.text().await?;
        let predictions: Predictions = serde_json::from_str(&data)?;
        anyhow::Ok(predictions)
    }
    pub async fn cancel(&self, id: String) -> anyhow::Result<Prediction> {
        let api_key = self.config.get_api_key()?;
        let base_url = self.config.get_base_url();
        let endpoint = format!("{base_url}/predictions/{id}/cancel");
        let client = reqwest::Client::new();
        let response = client
            .post(endpoint)
            .header("Authorization", format!("Token {api_key}"))
            .send()
            .await?;
        let data = response.text().await?;
        let prediction: Prediction = serde_json::from_str(&data)?;
        anyhow::Ok(prediction)
    }
}
#[cfg(test)]
mod tests {
    use httpmock::prelude::*;
    use serde_json::json;
    use super::*;

    /// Version id returned by the mocked model-versions endpoint.
    const VERSION_ID: &str = "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa";
    /// Root for fixture `urls` that the code under test never follows.
    const REAL_API_BASE: &str = "https://api.replicate.com/v1";

    /// Builds a prediction response body for `id` with the given `status`; its
    /// `urls.cancel` and `urls.get` are rooted at `urls_base`.
    fn prediction_json(id: &str, status: &str, urls_base: &str) -> serde_json::Value {
        json!({
            "id": id,
            "model": "replicate/hello-world",
            "version": VERSION_ID,
            "input": {
                "text": "Alice"
            },
            "logs": "",
            "error": null,
            "status": status,
            "created_at": "2023-09-08T16:19:34.765994657Z",
            "urls": {
                "cancel": format!("{urls_base}/predictions/{id}/cancel"),
                "get": format!("{urls_base}/predictions/{id}")
            }
        })
    }

    /// Mocks the model-versions listing used for `create`'s latest-version lookup.
    fn mock_latest_version(server: &MockServer) {
        server.mock(|when, then| {
            when.method(GET)
                .path("/models/replicate/hello-world/versions");
            then.status(200).json_body_obj(&json!({
                "next": null,
                "previous": null,
                "results": [{
                    "id": VERSION_ID,
                    "created_at": "2022-04-26T19:29:04.418669Z",
                    "cog_version": "0.3.0",
                    "openapi_schema": null
                }]
            }));
        });
    }

    #[tokio::test]
    async fn test_get() {
        let server = MockServer::start();
        let prediction_mock = server.mock(|when, then| {
            when.method(GET).path("/predictions/1234");
            then.status(200)
                .json_body_obj(&prediction_json("1234", "starting", REAL_API_BASE));
        });
        let config = ReplicateConfig::test(server.base_url()).unwrap();
        let prediction_client = PredictionClient::from(config);
        let prediction = prediction_client.get("1234".to_string()).await.unwrap();
        prediction_mock.assert();
        assert_eq!(prediction.id, "1234");
        assert_eq!(prediction.status, PredictionStatus::Starting);
    }

    #[tokio::test]
    async fn test_create() {
        let server = MockServer::start();
        let create_mock = server.mock(|when, then| {
            when.method(POST).path("/predictions");
            then.status(200).json_body_obj(&prediction_json(
                "gm3qorzdhgbfurvjtvhg6dckhu",
                "starting",
                REAL_API_BASE,
            ));
        });
        mock_latest_version(&server);
        let config = ReplicateConfig::test(server.base_url()).unwrap();
        let prediction_client = PredictionClient::from(config);
        let prediction = prediction_client
            .create(
                "replicate",
                "hello-world",
                json!({"text": "This is test input"}),
                false,
            )
            .await
            .unwrap();
        create_mock.assert();
        assert_eq!(prediction.version, VERSION_ID);
    }

    #[tokio::test]
    async fn test_list_predictions() {
        let server = MockServer::start();
        let list_mock = server.mock(|when, then| {
            when.method(GET).path("/predictions");
            then.status(200).json_body_obj(&json!({
                "next": null,
                "previous": null,
                "results": [
                    prediction_json("gm3qorzdhgbfurvjtvhg6dckhu", "starting", REAL_API_BASE),
                    prediction_json("gm3qorzdhgbfurvjtvhg6dckhu", "starting", REAL_API_BASE)
                ]
            }));
        });
        let config = ReplicateConfig::test(server.base_url()).unwrap();
        let prediction_client = PredictionClient::from(config);
        let predictions = prediction_client.list().await.unwrap();
        list_mock.assert();
        assert_eq!(predictions.results.len(), 2);
    }

    #[tokio::test]
    async fn test_create_and_reload() {
        let server = MockServer::start();
        let id = "gm3qorzdhgbfurvjtvhg6dckhu";
        // `reload` follows `urls.get`, so point the fixture's urls at the mock server to
        // keep the test hermetic instead of calling the real API.
        let urls_base = server.base_url();
        server.mock(|when, then| {
            when.method(POST).path("/predictions");
            then.status(200)
                .json_body_obj(&prediction_json(id, "starting", &urls_base));
        });
        let reload_mock = server.mock(|when, then| {
            when.method(GET).path(format!("/predictions/{id}"));
            then.status(200)
                .json_body_obj(&prediction_json(id, "processing", &urls_base));
        });
        mock_latest_version(&server);
        let config = ReplicateConfig::test(server.base_url()).unwrap();
        let prediction_client = PredictionClient::from(config);
        let mut prediction = prediction_client
            .create(
                "replicate",
                "hello-world",
                json!({"text": "This is test input"}),
                false,
            )
            .await
            .unwrap();
        assert_eq!(prediction.status, PredictionStatus::Starting);
        prediction.reload().await.unwrap();
        reload_mock.assert();
        assert_eq!(prediction.status, PredictionStatus::Processing);
    }

    #[tokio::test]
    async fn test_cancel() {
        let server = MockServer::start();
        let prediction_mock = server.mock(|when, then| {
            when.method(POST).path("/predictions/1234/cancel");
            then.status(200)
                .json_body_obj(&prediction_json("1234", "starting", REAL_API_BASE));
        });
        let config = ReplicateConfig::test(server.base_url()).unwrap();
        let prediction_client = PredictionClient::from(config);
        let prediction = prediction_client.cancel("1234".to_string()).await.unwrap();
        prediction_mock.assert();
        assert_eq!(prediction.id, "1234");
    }
}