aleph_alpha_client/
explanation.rs

1use serde::{Deserialize, Serialize};
2
3use crate::{Prompt, Task};
4
/// Input for a [crate::Client::explanation] request.
///
/// Borrows the prompt and target for the lifetime `'a`, so the task can be built
/// without copying the original completion input.
pub struct TaskExplanation<'a> {
    /// The prompt that typically was the input of a previous completion request
    pub prompt: Prompt<'a>,
    /// The target string that should be explained. The influence of individual parts
    /// of the prompt for generating this target string will be indicated in the response.
    pub target: &'a str,
    /// Granularity parameters for the explanation
    pub granularity: Granularity,
}
15
/// Granularity parameters for the [TaskExplanation]
///
/// The field is private; construct via [Granularity::default] and customize with
/// [Granularity::with_prompt_granularity].
#[derive(Default)]
pub struct Granularity {
    /// The granularity of the parts of the prompt for which a single
    /// score is computed.
    prompt: PromptGranularity,
}
23
24impl Granularity {
25    /// Returns a new [Granularity] based on the given one with the [Granularity::prompt]
26    /// being set to `prompt_granularity`.
27    pub fn with_prompt_granularity(self, prompt_granularity: PromptGranularity) -> Self {
28        Self {
29            prompt: prompt_granularity,
30        }
31    }
32}
33
34/// At which granularity should the target be explained in terms of the prompt.
35/// If you choose, for example, [PromptGranularity::Sentence] then we report the importance score of each
36/// sentence in the prompt towards generating the target output.
37/// The default is [PromptGranularity::Auto] which means we will try to find the granularity that
38/// brings you closest to around 30 explanations. For large prompts, this would likely
39/// be sentences. For short prompts this might be individual words or even tokens.
40#[derive(Serialize, Copy, Clone, PartialEq, Default)]
41#[serde(tag = "type", rename_all = "snake_case")]
42pub enum PromptGranularity {
43    /// Let the system decide which granularity is most suitable for the given input.
44    /// This will result
45    #[default]
46    Auto,
47    Word,
48    Sentence,
49    Paragraph,
50}
51
52impl PromptGranularity {
53    fn is_auto(&self) -> bool {
54        self == &PromptGranularity::Auto
55    }
56}
57
/// Body sent to the Aleph Alpha API for an explanation request
#[derive(Serialize)]
struct BodyExplanation<'a> {
    /// The prompt the target should be explained against.
    prompt: Prompt<'a>,
    /// The string whose generation should be explained.
    target: &'a str,
    // Omitted from the serialized body when `Auto`, deferring to the API's default.
    #[serde(skip_serializing_if = "PromptGranularity::is_auto")]
    prompt_granularity: PromptGranularity,
    /// Name of the model that should compute the explanation.
    model: &'a str,
}
67
/// Body received from the Aleph Alpha API for an explanation request
#[derive(Deserialize, Debug, PartialEq)]
pub struct ResponseExplanation {
    /// The Body contains an array of [Explanation]s one for each
    /// part of the target being explained.
    explanations: Vec<Explanation>,
}
75
/// The result of an explanation request.
///
/// Produced from a [ResponseExplanation] by flattening it to the items of its
/// last explanation.
#[derive(Debug, PartialEq)]
pub struct ExplanationOutput {
    /// Explanation scores for different parts of the prompt or target.
    pub items: Vec<ItemExplanation>,
}
82
83impl ExplanationOutput {
84    fn from(mut response: ResponseExplanation) -> ExplanationOutput {
85        ExplanationOutput {
86            items: response.explanations.pop().unwrap().items,
87        }
88    }
89}
90
/// The explanation for the target.
#[derive(Debug, Deserialize, PartialEq)]
pub struct Explanation {
    /// Explanation scores for different parts of the prompt or target.
    pub items: Vec<ItemExplanation>,
}
97
98/// Explanation scores for a [crate::prompt::Modality] or the target.
99/// There is one score
100/// for each part of a `modality` respectively the target with the parts being choosen according to
101/// the [PromptGranularity]
102#[derive(PartialEq, Deserialize, Debug)]
103#[serde(tag = "type", rename_all = "snake_case")]
104pub enum ItemExplanation {
105    Text { scores: Vec<TextScore> },
106    Image { scores: Vec<ImageScore> },
107    Target { scores: Vec<TextScore> },
108}
109
/// Score for the part of a text-modality
#[derive(Debug, PartialEq, Deserialize, Clone)]
pub struct TextScore {
    // Start of the scored span within the text. NOTE(review): the unit (character,
    // byte, or token offset) is not evident from this file — confirm against the API docs.
    pub start: u32,
    // Length of the scored span, in the same unit as `start`.
    pub length: u32,
    // Importance of this span for generating the target.
    pub score: f32,
}
117
/// Resembles the actual response from the API for deserialization
///
/// Rectangle of an image region as nested under `rect` in the wire format.
/// NOTE(review): whether the coordinates are pixels or relative fractions is not
/// evident from this file — confirm against the API docs.
#[derive(Deserialize)]
struct BoundingBox {
    top: f32,
    left: f32,
    width: f32,
    height: f32,
}
126
/// Resembles the actual response from the API for deserialization
///
/// Intermediate wire type: the API nests the rectangle under `rect`, which is
/// flattened into [ImageScore] via its `From` impl.
#[derive(Deserialize)]
struct ImageScoreWithRect {
    rect: BoundingBox,
    score: f32,
}
133
/// Score for a part of an image.
//
// Deserialized through `ImageScoreWithRect` (see `#[serde(from = ...)]`) so the
// nested `rect` object of the wire format is flattened into these fields.
#[derive(Debug, PartialEq, Deserialize, Clone)]
#[serde(from = "ImageScoreWithRect")]
pub struct ImageScore {
    pub left: f32,
    pub top: f32,
    pub width: f32,
    pub height: f32,
    pub score: f32,
}
144
145impl From<ImageScoreWithRect> for ImageScore {
146    fn from(value: ImageScoreWithRect) -> Self {
147        Self {
148            left: value.rect.left,
149            top: value.rect.top,
150            width: value.rect.width,
151            height: value.rect.height,
152            score: value.score,
153        }
154    }
155}
156
157impl Task for TaskExplanation<'_> {
158    type Output = ExplanationOutput;
159
160    type ResponseBody = ResponseExplanation;
161
162    fn build_request(
163        &self,
164        client: &reqwest::Client,
165        base: &str,
166        model: &str,
167    ) -> reqwest::RequestBuilder {
168        let body = BodyExplanation {
169            model,
170            prompt: self.prompt.borrow(),
171            target: self.target,
172            prompt_granularity: self.granularity.prompt,
173        };
174        client.post(format!("{base}/explain")).json(&body)
175    }
176
177    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output {
178        ExplanationOutput::from(response)
179    }
180}